{ "cells": [ { "cell_type": "markdown", "id": "5fbc2d16-59f9-4be3-b93e-1a5440c7efd0", "metadata": {}, "source": [ "# Tutorial 9 - Inverse PDE Problem" ] }, { "cell_type": "code", "execution_count": null, "id": "285b840c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "1afe6a1e-3ab4-4f3f-ad47-f6cd66419504", "metadata": {}, "source": [ "In this tutorial we revisit the DIY KAN concept, this time to solve an inverse PDE problem: we approximate the solution of a PDE and, at the same time, identify an unknown coefficient of the equation from data." ] }, { "cell_type": "code", "execution_count": 1, "id": "0a2ef2a6-f681-427f-8252-ade2111ce0e6", "metadata": {}, "outputs": [], "source": [ "import os\n", "# Suppress TensorFlow/XLA C++ logging; set before importing JAX so that it takes effect\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", "\n", "from typing import List\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from flax import nnx\n", "import optax\n", "\n", "from jaxkan.layers.Spline import SplineLayer\n", "from jaxkan.pikan.sampling import get_collocs_sobol\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "id": "f24a5679-f8e3-4ab4-8779-a244dfb02b1f", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "60e70a5d-0340-4bdc-af4e-150d0098a87e", "metadata": {}, "source": [ "## Data Generation" ] }, { "cell_type": "markdown", "id": "166ad90d-430e-45ca-86a8-e6e4dbc1c943", "metadata": {}, "source": [ "For this example, we work with the diffusion equation,\n", "\n", "$$ \\frac{\\partial u}{\\partial t} - D \\frac{\\partial^2 u}{\\partial x^2} = 0,$$\n", "\n", "on the domain $\\Omega = [0,1]\\times [0,1]$, i.e. $t \\in [0,1]$ and $x \\in [0,1]$, subject to the initial and boundary conditions\n", "\n", "$$ u\\left(t=0, x\\right) = \\sin\\left(\\pi x\\right), $$\n", "\n", "$$ u\\left(t, x=0\\right) = u\\left(t, x=1\\right) = 0. $$\n", "\n", "We intentionally leave the diffusion coefficient $D$ unspecified: apart from solving the PDE, we also want to estimate it from \"experimental\" data. 
The PDE's analytical solution is\n", "\n", "$$ u(t,x) = \\sin\\left(\\pi x\\right) \\cdot \\exp\\left(-D\\pi^2 t\\right), $$\n", "\n", "and we use it to generate mock experimental data, corrupted with additive Gaussian noise, for $D = 0.25$." ] }, { "cell_type": "code", "execution_count": 2, "id": "b986e75a-6d4a-402f-bea7-36d13f4a7866", "metadata": {}, "outputs": [], "source": [ "seed = 42\n", "\n", "# Collocation points for the PDE residual: (t, x) in [0,1] x [0,1]\n", "pde_collocs = get_collocs_sobol(ranges=[(0,1), (0,1)], total_points=2**12, seed=seed)\n", "\n", "# Collocation points for the IC: t = 0, x in [0,1]\n", "ic_collocs = get_collocs_sobol(ranges=[(0,0), (0,1)], total_points=2**6, seed=seed)\n", "ic_data = jnp.sin(jnp.pi*ic_collocs[:,1]).reshape(-1,1)\n", "\n", "# Collocation points for the BCs: x = 0 and x = 1, t in [0,1]\n", "bc1_collocs = get_collocs_sobol(ranges=[(0,1), (0,0)], total_points=2**6, seed=seed)\n", "bc1_data = jnp.zeros(bc1_collocs.shape[0]).reshape(-1,1)\n", "\n", "bc2_collocs = get_collocs_sobol(ranges=[(0,1), (1,1)], total_points=2**6, seed=seed)\n", "bc2_data = jnp.zeros(bc2_collocs.shape[0]).reshape(-1,1)\n", "\n", "# Concatenate IC/BC points and targets into single arrays\n", "bc_collocs = jnp.concatenate([ic_collocs, bc1_collocs, bc2_collocs], axis=0)\n", "bc_data = jnp.concatenate([ic_data, bc1_data, bc2_data], axis=0)" ] }, { "cell_type": "code", "execution_count": 3, "id": "6981f02b-a9d8-4afd-9764-d78dae055a6f", "metadata": {}, "outputs": [], "source": [ "# Analytical solution; tau plays the role of the diffusion coefficient D\n", "def u(t, x, tau):\n", "    return jnp.sin(jnp.pi*x)*jnp.exp(-tau*(jnp.pi**2)*t)\n", "\n", "# Split the key so that point selection and noise use independent random streams\n", "key = jax.random.PRNGKey(seed)\n", "key_idx, key_noise = jax.random.split(key)\n", "\n", "# Pick 1000 of the PDE collocation points as measurement locations\n", "idxs = jax.random.choice(key_idx, jnp.arange(pde_collocs.shape[0]), (1000,), replace=False)\n", "exp_collocs = pde_collocs[idxs]\n", "\n", "# Noisy measurements for D = 0.25; the noise std equals the std of the clean data\n", "u_vals = u(exp_collocs[:,0], exp_collocs[:,1], 0.25).reshape(-1,1)\n", "noise = u_vals.std()*jax.random.normal(key_noise, shape=(1000,1))\n", "exp_data = u_vals + noise" ] }, { "cell_type": "code", "execution_count": null, "id": "21e57a59-fa44-42ad-8d18-1e4983072483", "metadata": {}, 
"outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4e9d9bba-71e2-4d5b-95b1-36167fb70c1a", "metadata": {}, "source": [ "## KAN Model" ] }, { "cell_type": "markdown", "id": "2fd54c6c-3d28-4864-af38-3b22875d0f0a", "metadata": {}, "source": [ "We define a KAN class built from `SplineLayer`s, which also holds a trainable parameter, $\\tau$, representing the unknown diffusion coefficient $D$." ] }, { "cell_type": "code", "execution_count": 4, "id": "80600c7a-09d3-4575-a586-b587bbdf6fb0", "metadata": {}, "outputs": [], "source": [ "class MyKAN(nnx.Module):\n", "\n", "    def __init__(self, layer_dims: List[int], k: int = 3, G: int = 5, add_bias: bool = True, seed: int = 42):\n", "\n", "        # One SplineLayer per pair of consecutive entries in layer_dims\n", "        self.layers = nnx.List([\n", "            SplineLayer(\n", "                n_in=layer_dims[i],\n", "                n_out=layer_dims[i + 1],\n", "                k=k,\n", "                G=G,\n", "                residual=nnx.silu,\n", "                external_weights=True,\n", "                add_bias=add_bias,\n", "                seed=seed)\n", "            for i in range(len(layer_dims) - 1)\n", "        ])\n", "\n", "        # Trainable parameter tau: the unknown diffusion coefficient D we want to identify\n", "        # It is initialized at 1.0, away from the true value of 0.25\n", "        self.tau = nnx.Param(jnp.array([1.0]))\n", "\n", "    def __call__(self, x):\n", "        # Forward pass through the stacked spline layers\n", "        for layer in self.layers:\n", "            x = layer(x)\n", "\n", "        return x" ] }, { "cell_type": "code", "execution_count": 5, "id": "b350b56f-5daa-411f-9090-b544760c34ef", "metadata": {}, "outputs": [], "source": [ "# Initialize a MyKAN model instance: inputs (t, x), output u(t, x)\n", "n_in = pde_collocs.shape[1]\n", "n_out = 1\n", "n_hidden = 6\n", "\n", "layer_dims = [n_in, n_hidden, n_hidden, n_out]\n", "\n", "model = MyKAN(layer_dims=layer_dims, k=3, G=5, add_bias=True, seed=seed)" ] }, { "cell_type": "code", "execution_count": null, "id": "09778d14-8b72-4449-95fc-fc8967006935", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f681d135-c885-4e59-87d7-1ea16a157c50", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "id": "ca097309-be81-4959-86c9-2365faab1362", "metadata": {}, "source": [ "PIKANs provide a 
unified framework for solving both the forward and the inverse PDE problem. For the inverse problem, however, the \"experimental\" data must also enter the loss function: the total loss is the sum of the mean squared PDE residual, IC/BC error and data error, and since $\\tau$ is an `nnx.Param`, it is optimized jointly with the network weights." ] }, { "cell_type": "code", "execution_count": 6, "id": "6723a412-c313-453c-aa88-9274687ee54e", "metadata": {}, "outputs": [], "source": [ "opt_type = optax.adam(learning_rate=0.001)\n", "\n", "optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)" ] }, { "cell_type": "code", "execution_count": 7, "id": "f18ab849-3c05-418a-a30b-3928f77c332d", "metadata": {}, "outputs": [], "source": [ "# PDE residual u_t - tau * u_xx, evaluated at each collocation point\n", "def pde_loss(model, collocs):\n", "    tau = model.tau[0]\n", "\n", "    # Scalar network output for a single (t, x) point, so that jax.grad applies\n", "    def u_fn(t, x):\n", "        return model(jnp.array([[t, x]]))[0, 0]\n", "\n", "    u_t_fn = jax.grad(u_fn, argnums=0)\n", "    u_x_fn = jax.grad(u_fn, argnums=1)\n", "    u_xx_fn = jax.grad(u_x_fn, argnums=1)\n", "\n", "    # Vectorize the pointwise residual over all collocation points\n", "    pde_res = jax.vmap(\n", "        lambda t, x: u_t_fn(t, x) - tau * u_xx_fn(t, x),\n", "        in_axes=(0, 0),\n", "    )(collocs[:, 0], collocs[:, 1]).reshape(-1, 1)\n", "\n", "    return pde_res\n", "\n", "# Define a single jit-compiled training step\n", "@nnx.jit\n", "def train_step(model, optimizer, collocs, bc_collocs, bc_data, exp_collocs, exp_data):\n", "\n", "    def loss_fn(model):\n", "        # PDE residual term\n", "        pde_res = pde_loss(model, collocs)\n", "        total_loss = jnp.mean(pde_res**2)\n", "\n", "        # IC/BC term\n", "        bc_res = model(bc_collocs) - bc_data\n", "        total_loss += jnp.mean(bc_res**2)\n", "\n", "        # Experimental data term (needed to identify tau)\n", "        exp_res = model(exp_collocs) - exp_data\n", "        total_loss += jnp.mean(exp_res**2)\n", "\n", "        return total_loss\n", "\n", "    # Gradients w.r.t. all nnx.Param variables, including tau\n", "    loss, grads = nnx.value_and_grad(loss_fn)(model)\n", "    optimizer.update(model, grads)\n", "\n", "    return loss" ] }, { "cell_type": "code", "execution_count": 8, "id": "b06282a5-8899-4801-9c9e-d8b73b55d74a", "metadata": {}, "outputs": [], "source": [ "# Number of training epochs (full-batch Adam steps)\n", "num_epochs = 5000\n", "\n", "for epoch in range(num_epochs):\n", "    # Perform one optimization step and get the current loss\n", "    loss = train_step(model, optimizer, pde_collocs, 
bc_collocs, bc_data, exp_collocs, exp_data)" ] }, { "cell_type": "code", "execution_count": null, "id": "96c368c8-cfb6-4dde-abba-da6c03c6fe77", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "faab949e-dadb-4cec-bc13-63b10cb609c5", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "id": "5ef7d214-6efa-4124-a587-b47e98cac7d6", "metadata": {}, "source": [ "The following plot shows the trained neural network on the entire domain, approximating the solution, $u$, of the equation." ] }, { "cell_type": "code", "execution_count": 9, "id": "7798eb68-d6b5-47ec-b845-cf3d60dccf23", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAGGCAYAAADil5DZAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAfjNJREFUeJztvXucFNW57v9U90zPoAhIEEYQwUsSL6gkIAjGjeZg8BKMOfGS6EbEqEcjRp0kCt6QuAUvCQe3omzvxMRINNHtTwiKIJ9EJSFRydYoJCp4O84gSWQQdC7d6/fHOM1019vM27VW1VpV9X799B+UtapWVfXlmee9LE8ppSAIgiAIgiCkloztCQiCIAiCIAh2EUEoCIIgCIKQckQQCoIgCIIgpBwRhIIgCIIgCClHBKEgCIIgCELKEUEoCIIgCIKQckQQCoIgCIIgpBwRhIIgCIIgCClHBKEgCIIgCELKEUEoCBocffTROProo40ec+PGjfA8Dw888IDR45pm2bJlGDlyJOrr6+F5Hj766CNjx77uuuvgeV7Jto6ODlx++eUYOnQoMpkMTj75ZADAxx9/jHPPPRcNDQ3wPA+XXnqpsXl0MXz4cJx99tnGj5t0HnjgAXieh40bN9qeiiAIPSCCUEgVr7zyCk455RQMGzYM9fX1GDJkCI499ljcdtttkc/loYcewvz58yM/rwn+8Y9/4LTTTkOvXr2wYMECPPjgg9h1113JfbtEQdervr4egwcPxqRJk/Cf//mf2Lp1K+uc9913H2655RaccsopWLRoES677DIAwJw5c/DAAw/gwgsvxIMPPogpU6YYu07blN+78tcf/vAH21ME0PkMHn/8cdvTEARBA0/WMhbSwgsvvIBjjjkGe++9N6ZOnYqGhga8++67+MMf/oA333wTb7zxRtXH7HIHV61aVfXYr3/963j11Vd97olSCq2traitrUU2m636uFGwbNkyHH/88Vi+fDkmTpy4030feOABTJs2DT/+8Y+xzz77oL29HU1NTVi1ahWWL1+OvffeG0888QQOPfTQ4piOjg50dHSgvr6+uO3b3/42nnvuObz33nslxz/iiCNQU1OD5557zuxFdqO1tRWZTAa1tbWhnYOi/N6Vc9xxx2HAgAGRzomid+/eOOWUU3yudj6fR3t7O+rq6nyOryAIblFjewKCEBU33HAD+vbtiz/96U/o169fyf/btGmTnUkRdLloLtN1v8rv4844/vjjMXr06OK/Z86ciZUrV+LrX/86TjrpJLz++uvo1asXAKCmpgY1NaVfT5s2bSLPt2nTJ
hx00EHVX0QV1NXVhXr8nii/d3Ehm806+0eNIAilSMhYSA1vvvkmDj74YFJUDBw4sOTfHR0duP7667Hffvuhrq4Ow4cPx5VXXonW1tadnqNSztSqVavgeV7RSTz66KOxZMkSvP3228Xw3/DhwwFUziFcuXIljjrqKOy6667o168fvvGNb+D1118v2acr9+6NN97A2WefjX79+qFv376YNm0atm/f3uM9AoBHHnkEo0aNQq9evTBgwAD8+7//O95///3i/z/66KMxdepUAMDhhx8Oz/MC59d99atfxTXXXIO3334bP//5z33XAey4H88++yz++te/Fu9X1z3dsGEDlixZUty+ceNG9nMAgL///e/41re+hYaGBtTX12OvvfbCt7/9bWzZsqW4D5VD+NZbb+HUU09F//79scsuu+CII47AkiVLyPP96le/wg033IC99toL9fX1+F//638FcqR3xkcffYSzzz4bffv2Rb9+/TB16lSsXbvW916qlPd69tlnF9+DXfzkJz/B+PHj8bnPfQ69evXCqFGj8Oijj5bs43ketm3bhkWLFhWfQde9qvQc7rjjDhx88MGoq6vD4MGDcdFFF/lyUI8++miMGDECr732Go455hjssssuGDJkCG6++eaAd0gQhJ0hDqGQGoYNG4bVq1fj1VdfxYgRI3a677nnnotFixbhlFNOwQ9+8AP88Y9/xNy5c/H666/jscce057LVVddhS1btuC9997D//2//xdAZ9itEs888wyOP/547LvvvrjuuuvwySef4LbbbsORRx6Jl156yfdDftppp2GfffbB3Llz8dJLL+Gee+7BwIEDcdNNN+10Xl0hysMPPxxz585Fc3Mzbr31Vjz//PN4+eWX0a9fP1x11VX44he/iLvuuqsYytxvv/0C34spU6bgyiuvxNNPP43zzjvP9//32GMPPPjgg7jhhhvw8ccfY+7cuQCAAw88EA8++CAuu+wy7LXXXvjBD35Q3J9LW1sbJk2ahNbWVlx88cVoaGjA+++/jyeffBIfffQR+vbtS45rbm7G+PHjsX37dnz/+9/H5z73OSxatAgnnXQSHn30UXzzm98s2f/GG29EJpPBD3/4Q2zZsgU333wzzjzzTPzxj39kzXPLli3YvHlzyTbP8/C5z30OQGeawTe+8Q0899xzuOCCC3DggQfiscceKwr3oNx666046aSTcOaZZ6KtrQ0PP/wwTj31VDz55JM48cQTAQAPPvggzj33XIwZMwbnn38+AOz0/XDddddh9uzZmDhxIi688EKsX78ed955J/70pz/h+eefLwnL/+tf/8Jxxx2H//2//zdOO+00PProo7jiiitwyCGH4Pjjj9e6NkEQylCCkBKefvpplc1mVTabVePGjVOXX365euqpp1RbW1vJfmvXrlUA1Lnnnluy/Yc//KECoFauXFncNmHCBDVhwoTiv++//34FQG3YsKFk7LPPPqsAqGeffba47cQTT1TDhg3zzXPDhg0KgLr//vuL20aOHKkGDhyo/vGPfxS3/eUvf1GZTEadddZZxW2zZs1SANQ555xTcsxvfvOb6nOf+1ylW6OUUqqtrU0NHDhQjRgxQn3yySfF7U8++aQCoK699lrfdf7pT3/a6TG5+/bt21d96Utf8l1HdyZMmKAOPvhg39hhw4apE088kTxnT8/h5ZdfVgDUI488stNrGDZsmJo6dWrx35deeqkCoH7/+98Xt23dulXts88+avjw4Sqfz5ec78ADD1Stra3FfW+99VYFQL3yyis7PW/XdVCvurq64n6PP/64AqBuvvnm4raOjg511FFH+d5L5e/ZLqZOnep7P27fvr3k321tbWrEiBHqq1/9asn2XXfdteT+lM+/6zls2rRJ5XI59bWvfa14j5RS6vbbb1cA1H333VcyTwDqZz/7WXFba2uramhoUN/61rd85xIEQQ8JGQup4dhjj8Xq1atx0kkn4S9/+QtuvvlmTJo0CUOGDMETT
zxR3G/p0qUAgMbGxpLxXQ5UeVgwbD744AOsXbsWZ599Nvr371/cfuihh+LYY48tzrc7F1xwQcm/jzrqKPzjH/9AS0tLxfP8+c9/xqZNm/C9732vJIfxxBNPxAEHHBDqdffu3ZtdbWySLgfwqaeeYofUgc73yJgxY/CVr3yluK137944//zzsXHjRrz22msl+0+bNg25XK7476OOOgpAZ9iZw4IFC7B8+fKS129/+9uS+dTU1ODCCy8sbstms7j44ovZ10TRldMJdLp1W7ZswVFHHYWXXnop0PGeeeYZtLW14dJLL0Ums+Pn57zzzkOfPn1877HevXvj3//934v/zuVyGDNmDPu+CYLARwShkCoOP/xw/OY3v8G//vUvrFmzBjNnzsTWrVtxyimnFH/E3377bWQyGey///4lYxsaGtCvXz+8/fbbkc6563xf/OIXff/vwAMPxObNm7Ft27aS7XvvvXfJv3fffXcAnT/qQc5zwAEHhHrdH3/8MXbbbbfQjl+JffbZB42NjbjnnnswYMAATJo0CQsWLCjJH6R4++23Kz6Prv/fnSDPoztjxozBxIkTS17HHHNMyXz23HNPX9oBNcdqePLJJ3HEEUegvr4e/fv3xx577IE777yzx/tTiUrvsVwuh3333dd33/baay9fdfLuu+/Ovm+CIPARQSikklwuh8MPPxxz5szBnXfeifb2djzyyCMl+wRpk1FpTD6fDzTPoFSq7FQOdpl67733sGXLFp8A16Ga5/DTn/4U//M//4Mrr7wSn3zyCb7//e/j4IMP9rW30cGl58G9N7///e9x0kknob6+HnfccQeWLl2K5cuX44wzzohs3i7dN0FIOiIIhdTT1c7jgw8+ANBZfFIoFPD3v/+9ZL/m5mZ89NFHGDZsWMVjdTk/5RWTlLvGFZxd51u/fr3v/61btw4DBgyo2BS6GnZ2nvXr1+/0unV48MEHAQCTJk0ydsxqngMAHHLIIbj66qvxu9/9Dr///e/x/vvvY+HChRWPP2zYsIrPo+v/R8mwYcPwwQcf4OOPPy7ZTs1x9913J1eVKb83v/71r1FfX4+nnnoK55xzDo4//viKPSd138ttbW3YsGFD5PdNEIQdiCAUUsOzzz5LOgtdOXhdYawTTjgBAHyriMybNw8AitWVFF3Vlb/73e+K2/L5PO666y7fvrvuuisr9Lbnnnti5MiRWLRoUckP+auvvoqnn366OF9dRo8ejYEDB2LhwoUl7XV++9vf4vXXX9/pdQdl5cqVuP7667HPPvvgzDPPNHZc7nNoaWlBR0dHybZDDjkEmUxmpy2GTjjhBKxZswarV68ubtu2bRvuuusuDB8+PPS+iNR8Ojo6cOeddxa35fN5cgWe/fbbD+vWrcOHH35Y3PaXv/wFzz//fMl+2WwWnueVOIcbN24kVyTZddddWUsXTpw4EblcDv/5n/9Z8lm89957sWXLllDeY4Ig8JC2M0JquPjii7F9+3Z885vfxAEHHIC2tja88MILWLx4MYYPH45p06YBAA477DBMnToVd911Fz766CNMmDABa9aswaJFi3DyySeX5G6Vc/DBB+OII47AzJkz8c9//hP9+/fHww8/7BMdADBq1CgsXrwYjY2NOPzww9G7d29MnjyZPO4tt9yC448/HuPGjcN3v/vdYtuZvn374rrrrjNyf2pra3HTTTdh2rRpmDBhAr7zne8U284MHz68uFRcUH77299i3bp16OjoQHNzM1auXInly5dj2LBheOKJJ4w24+Y+h5UrV2L69Ok49dRT8YUvfAEdHR148MEHkc1m8a1vfavi8WfMmIFf/vKXOP744/H9738f/fv3x6JFi7Bhwwb8+te/LimYMEHXvStn/Pjx2HfffTF58mQceeSRmDFjBjZu3IiDDjoIv/nNb8g/OM455xzMmzcPkyZNwne/+11s2rQJCxcuxMEHH1xSdHTiiSdi3rx5OO6443DGGWdg06ZNWLBgAfbff3/8z
//8T8kxR40ahWeeeQbz5s3D4MGDsc8++2Ds2LG+c++xxx6YOXMmZs+ejeOOOw4nnXQS1q9fjzvuuAOHH354SQGJIAgRY7PEWRCi5Le//a0655xz1AEHHKB69+6tcrmc2n///dXFF1+smpubS/Ztb29Xs2fPVvvss4+qra1VQ4cOVTNnzlSffvppyX5UC48333xTTZw4UdXV1alBgwapK6+8Ui1fvtzXdubjjz9WZ5xxhurXr58CUGz5QbWdUUqpZ555Rh155JGqV69eqk+fPmry5MnqtddeK9mnq13Lhx9+WLK9UhsWisWLF6svfelLqq6uTvXv31+deeaZ6r333iOPV03bma5XLpdTDQ0N6thjj1W33nqramlp8Y3RbTujFO85vPXWW+qcc85R++23n6qvr1f9+/dXxxxzjHrmmWd85yhvq/Lmm2+qU045RfXr10/V19erMWPGqCeffLJkn662M+VtbSo943J21namfPw//vEPNWXKFNWnTx/Vt29fNWXKlGJbnfLz/PznP1f77ruvyuVyauTIkeqpp54i287ce++96vOf/7yqq6tTBxxwgLr//vvJZ7Nu3Tr1b//2b6pXr14KQPFeVXrf3X777eqAAw5QtbW1atCgQerCCy9U//rXv0r2qfS8qXkKgqCPrGUsCIKQUDZu3Ih99tkH999/f+DVZARBSAeSQygIgiAIgpByRBAKgiAIgiCkHBGEgiAIgiAIKceqIPzd736HyZMnY/DgwfA8j2xnUM6qVavw5S9/GXV1ddh///3xwAMPhD5PQRCEODJ8+HAopSR/UBBihC1tZFUQbtu2DYcddhgWLFjA2n/Dhg048cQTccwxx2Dt2rW49NJLce655+Kpp54KeaaCIAiCIAjhY0sbOVNl7HkeHnvsMZx88skV97niiiuwZMkSvPrqq8Vt3/72t/HRRx9h2bJlEcxSEARBEAQhGqLURrFqTL169Wrf0kmTJk3CpZdeyj5GoVDA//t//w+77bZboLVqBUEQBEGIFqUUtm7disGDBxtv/M7h008/RVtbW6CxuVzOaOP9ckxoIyBmgrCpqQmDBg0q2TZo0CC0tLTgk08+Qa9evXxjWltbS5agev/99yNfVkoQBEEQBH3effdd7LXXXpGe89NPP0VDr77YgmCCsKGhAX/5y19KRGFdXR3q6uqMzC+INqKIlSAMwty5czF79mzf9m8dOx+1tbybJAguUMg47Gg71K+gkHX4Pmng9PPXJFNwInPJGraerc55VcAIG/X55Myjvf0TPP2rC7HbbrsFOq8ObW1t2II2/ATj0atK2fQJOvDDphd8gm3WrFnGlh01RawEYUNDA5qbm0u2NTc3o0+fPhUV8MyZM9HY2Fj8d0tLC4YOHQqv967wancJdb5ByeST+eXopfxLHwCUxhdw1uA8TOOSCOPep7gJLJefv6BHFM+Wer/rnJfzXUZ9L1DnrOazaDPVa9dMLXp51cmmjPKAQqez2adPn+J2U+4gEEwbUcRKEI4bNw5Lly4t2bZ8+XKMGzeu4phKtmxrfQ0KuZ1ffhTCjBRJETwVV/4id2UeQvwEUhjIPXAb+b7wE8V7VucP2aAUyvIEy/9tg0wWqPZWZBSAAtCnT58SQWiSINqIwqog/Pjjj/HGG28U/71hwwasXbsW/fv3x957742ZM2fi/fffx89+9jMAwAUXXIDbb78dl19+Oc455xysXLkSv/rVr7BkyZKqz92eqwF6EISmHa1MoWD4eMHnl6eOZ8GZLKTsC17nmYlYcRfTP5jc7x7qvEl24uUzED4672VOpCDOz9DLeMhU6VB6qvrrtaWNrArCP//5zzjmmGOK/+4K7U6dOhUPPPAAPvjgA7zzzjvF/7/PPvtgyZIluOyyy3Drrbdir732wj333INJkyZVfe6OuiyQ27lhrvPFSv3w5zUSrSixRok6CvZ1MN4NLola00QhiPMx/jIUokPnR9mGm2OTJAtgk5h+X7iUJhIV2SxQ7
WVnA7w9bWkjZ/oQRkVLSwv69u2Lf5u+GDV1O88h1BErXHGhcw7TgpU1zrBoct2FNY1LAlgQgpLUPGcu1PdWFG6tS2I/qCDkOoTlIeL2tu347YNnYcuWLaGFXivRpRvu730Mdqkyh3C76sC0j5+1Mu9qiVUOoUnac1moHhxC7o836QYSd5b75RCFECVDxozzFjKmBZzZa81nzKZnm/5CL2ikwbgudoXosfUHhi13yBUhyhVmLgk4HaIWf66SyQTMIYwJ8XgKgiAIgiAIQmik1iFsq88iX7/j8qm/tLl/fbOdP8NhZNPOZNC56PzVrrSKLCJwKw07ITr3yqT7KXlXdjDt8rqeoG/awbThTLriSoaB6fvp+vtRl0y2+qKSTICiElukVhDm67JA3Y4f2LyOINQQenTxCXEOS6KTI9h0hBlb/FLXxfzy4Z6DW/ChI6a4X5ihhwKZPwRJ/jHUIeh7wPXQWBwFa9iflTQWTwg0Ga8zbFzVmBhl+qRWENbW5VFTt0N6FQpE9/Q874sgTyWGMYUe9WWmIzAph5Daj9vuhSMItBxNdo6JJdFJ7hi+qKPui7h6ZtFyth12QvT+YDErWKPIew0qOqXAS4+g9537Hiv/jLnwmctkPWSqnEe1jqJNUisIc7meBSEFJRJJMckWmP5tpgUm6TgSP4Ycgem6o+mS6OQ6CyaFqOkfuShc0yjQeT42cF2scp93FI5oUNGZBPdSl7DvgeuOeLVkM52vqsaEM5VQSK8grCuUCkKm0KPumI4gZAtRDYFJnoMQP6zwsOmQNFOEmXQ0gQrzI74cdcLN/B/N4Of1HctWKNjgNehC3XdXcpu49yQKsWpLdJqv3A8mOlx2LwU3yWQCOITEb7KrJEu+C4IgCIIgCFWTWoewvr4Dtb06iv/O64SCNdxA7jYKnXMEPh71jqGatHKdkAgcR9IxMuykcR1H7lhOqNZ0cUsUFaFRFKm48leuLafSlgvJfba2lvkrJ475kq4T9J66kB/IwctWX1QSoxTC9ArCmpoCamoK3f7t36ejw//kucIxitAy9eOiJf4CHo+eL/NTQwg4UiARP15UHqSO0OEKM67A5P5ABp6zYVFn+ucsirw96h67IrpsCR/T1x+FwLQRvnYldF0JlwRm0nIBg5LJIEDIOD6kVhDW98qjttfOi0oyxCKEOq4c24U0LfQ0zsGptDY/X8fFJHM/rbxCG+6aYYGZlErpsEUn3/12W2Bq/TFiIffVpTxIiqSKsLi4gRSdgrDKMeFMJRRSKwhrc3nkuhWVUG4g9ZcAP+zrP6dLYjJoEY1plzNDiDq2MKWKsZkV2mxRZ0lMUnD8AlthWgrqukz/COsIMxuhVdPiUkeEcZ9FFM5k2M/C9GfA9L1LKjqf9/Jn7UJXgEzGQ6bKecSpqCS1gjBXVyoIa2r9P7cd7X5xEU0OoX++kYhJxjYdB5KCfe8MH48eGz8xyZqHJafSeJ6i4yHycuLW6gaIRsRThC0m4xbOroRLAjPOTl9QArWdceeR9Uic3ExBEARBEAQhBFLrEJatXIc8lY+W8bdc1nEDqbA033HzbSJ781HH44a+OU6iydzDimM1zkGGoJnH08oZpf5ajsBdDIrOX4K2Ut1Jx03HWdKZTBmuh7N1XF3yeIbzHsNeDjJu+Y2VSKor57pz3kUmKyHjRFJfJgip7wvquVPCsZ2ZB8fOl6PEJDd8HbJwNJ3LaFqEmT4HBfd4dI5jcDHJETBcsaZVjc08B3lejbE6GBeTZcezlRuZNuFIwRGTSS2MEaIlUFFJjB53ugVht6un8va4pk+OeIO0kflDhBNEHI8Sfzpiku9g9jxWx22kiKIwxngxC4FxcUqtOFN+DgtFK+Q8qjhHHMWkSXRy9KJwIU3n2kXhknLuXxwrqrm4LDDj4vxxCbRSiYrPPUitIKzLdorCLihnLUMVdxDb2oht1OeA2o/66yGb8+/IFY5aThpjP45o5B6rEmyxZlg46oSbKbjXQY7lC
EfDDmQUYlIHW2IyqNDRcSVdqtBOggtpq6I6iuUbbRXMJE3scfCyQKbKxYnjdJdSKwh7ZRXqu/3oUp+LDNFinCsc2/3ph2zHUUc4tlPix2gIlhChEeQ3ctERdVw3kH3eCIRjUCgHkt6PB1esRSIcTbtwjB8+HXfMtHCkiMKFdEU4pq0Vjw4ui7rye+fCutDiECaUcoeQ+u7Jev6NpFgjhCOZf0icgysc24nz0iI2eI5j0Lw61wtjuOiEkSNxDcvGhi0aAXvCUeccOnDnVy4IdESOCEeaoPc0CveOuk/mVz5xx4UMGxfEHodMAIcwTq1c4jRXQRAEQRAEIQRS6xDuWgP06nb13DxAOuzr/wuNGttOODrkfoRryHYcmdcRtFqa63AlocVOJUw7iUHD3DrV06bD1MadRCrsy3SMdM7BdmWY5/AdP+C4agi7ohqIn5No2pmlsLV6ieth6XLi4gZSZD0P2Srnn40gkmOK1ArCukxpyJh6xlxBSIkw+nhUCJonEknByvx1oTqlc3OJC2VzpoQkVdxSU+OfnPHqYceFIwW7ZY0juC4cKUznMwYVWFHkMkYhdHSqynVCqyavw/SKO0HnAbi17F2cxZkNArWdiVEcNr2CMAvUdftx5uYBcrdxnTqdPMV2SrASbz4yT5H5XVP+nelOcQt9PMqZBIJ/sXIFkc7ygEHn4nLRii5c4UhVZZl2+TjCMYqK4qAFL4BbuYthF264UtxSzVwoXBKOQieBGlNLUYn75LIF1JVYbNTPA08kct1FSuhR0C6knbB0ubClvhup66qlei5qhFppoeenhvj5Nu0umsZ0u5s0YVo4Bg1VuxSmjkI4uuwuRhGStuUuUohwjA5xCBNKTUahpuSH2P+VTrpy5I8y7wNp2l2k9qPcMO5fKBzXkHQImeKXGtvOdO/4Qs9sWLqDmAsV9tVpJq5TGe07p6WQtOuh8LBzHCnRSC01aNqVi2NYOmx3USckHYVIND0XLi4LR9bcHAhvexkFr8qlR6rd3yYx0q6CIAiCIAhCGKTWIdylRmGXmh3KvZXZD5BffEI4UGTfQN7xbIWbfeOY/RAp6F6PxDyIvMoows0UNcQnhDpvTS3hYBp2yDg5hMabeqcsTG3SSWSHeDXCtFHk8rFdU0ecRFuFLC6v1FINOiviJB0v0/mqdkxcSK0g5OQQ0oKQElfEl7fhcDMFP4xMjCW/Q4hwc9l1kIUszO8jcj9qvtRqMDEMN7u0gov/+MkVfzprSLPvC2NdaZM5ioD5PEXTuXymc+2C5inaKGSpdF5ybAThZookNLC2jecpeIQG6GlMXEivIPQUcgFyCLnVyKbdRR3xR8F39cp3JCZCzYPq5k64sJT4o8Qq6ZBSgovpLlI5f9FAfXmH+4WRFPHHFXA2sNWH0SV30ZViFpfa5NiqeCbPIcJRG3EIE0ptRqG2hx8Tt5xzs04iOwTtu0XUPSMELPM7xXQImhxLXSsV4o0gBE1jTiRGIf6SIjApgrqLgZ1FwLi76FIxC1tMaRSzJFUkkscz7C6S5xDhWBEvo6r+IzRORSWpFYTlIWPKlUOesq/8m3SaWlMhaDp83XM4t/O8/m1c8UfBE3aEK0eO88+NTLNjOok6C+gmIQRtq00Od+UTl0iVu8jcz7S76ErFs8stcSqdl8JWCJpL1E2tC9RvdMR4XgCH0P602cTIzBQEQRAEQRDCILUOYV1Wla1UQu3l/5ucWpeQdBf5f6ezztuhFZILPjboH5rcohVQ10W5sKYLV8jBxFDCmW1nHi58os9HBJIdMqbgOKJWClkAa+FmlwtXogiPU+i4hhSu5ySmkaT3IUytIMx6tch6tcV/exkqTklVFBO7UUOJ5DgqnEsLxyh+6HlfGNxcwHKoMHWO1MhEKJxqpG2rcIWAuwqLncIVOyLRFjrha66IC4qIRDsi0XRls8vL9FU6LxcRjtUhRSUJJevVIOvt/PJzTJFIws2D06huNg+R42fweyB4ZTOgU91MF
a5wBSt3qT4SZuFK+MLRbZGYhJxE0/NNm0ikCFq44kr7m2rGulS4Qp7DUp6ia8jSdQklk6lBNrPDIaTEhUeIsByxmFnW4/VEyWgUd3Ax3QKH06yaC/s3kxRrwaub24iROWIbV+hxhSPdY9G/I7epdfmPP/VDoLOEnkvrO+ugU0AStji1JX7JwpWE9E4sF1OutL/RHetS4Qp5DuY99Z0zxkJS+hAmFJ9DyK1iJffzi8Q6ypZi5hW6/QdFcDHA7SVIwhSJbcQPLiXguCKRLfQo2O8pc18YpLg0nn8any+4qNAJP3NFovEQdEJ6JxpdqSTk9je6Y13KUwx8zoBCMug4kyQ9ZByjqQqCIAiCIAhhkGKHsLSohMQh1zCCaLMGdlxD2lgz6xrynzc1WGOs0zl04buGecN9A6Nw8GycUwpXeH0IxTXkY8M1jAtSZZxQMl4WGY8qU+1pILHNkkh0m+AikZ1iQtw66vfMdGhZRCJFukWiK5XN1ZwjCpHIxaRIjGJlFS5JFokUSReOSQ8Zp1YQduYQ7twhVIqqNCF2ZP6gK2ZyaY3TP/w6MEUice/I9BGH8g+tiETn3yfxyz8MKtiiaH9jvFgmApFoI/8wiuX30lbJzCXoqjFxobPKuLprkirjGMBpO8MWetSPnIYY0Gp3EzLUd16t1hs+uUUqoYvE2LmIgGmRSAmTrCXHjTUuAgFHYS3c7DAiEsMXiUlzFsUhTCgZZJHp3rSO+13G/UEn3t+KbE/DozYTfH0M8rPGfJMWyppEU+KvnTgB5ehxI0+1GmtkUiKR/J1iPscoRGLgLi6JCDUDSRCJYecjVhzrek5iBO1uyl3DKMRV2kQiRdTCMeq1kyk8BGg744iZwyG1ghCFjpKcvkzWfyso548KI1NCj3QNmbDHMkUiv0ceNZYzF7M/6O1kyJhSk4FPoeXeRdHuht38ugy2uBSRSBI0J9FG0UpU5zXtVoZdpKJToKJDkkUiRRQhaNdIukMYo6kKgiAIgiAIYZBeh7CjHeggg3pFsoRrqJUDRqCIwVyHkBpbS+Qfsv9oo8KtgZ2kCAoKCOeCMinYBSka4dwsGUYgQpfMuQR1DUlXkp2j6Ppf9/GqZBbXkH8Ok65hEtZjBpLhGibNMZS2M0kl3wbke7h8wuulWtVwQ8Y6YWQKShBSUCKRiy9krOEpl+cjdh7Pf0/yxH60kOKJMHoy1PF4Q7nrILdRX4bkDx9vLkHT4LQa/Dv/ZZZukUihU/DBFn8RnCOoSHRlxRRdbK0fbPLakhZW9gKsZRynkHF6BWFHG9DR7fKpp5b3O4he1u/d9Njgukp0XEMK0r0iFlamBFtt+Zc3s2iD2lZDunKMc1aCWo83AoeQ68LRQoz4caHyIynK/xZh6nwt15BCw6mKAvoLO3yR6J9H+EKPezz2WI2K4ijOwRGJrjTIrnhex5tkk+cNWSRSuCgcxSFMKKrQAVXYUZThdTBlPCEcKZFIOomka0iIP6pwhfpBIz5XdNEL7w0ZtP+h26uoAOSNYje19m9jiz+dKmCGSNDRW8aXBSWuQadptA46QtTknHVEE3k8030IoxBwBKH3P3RISCSlSXYUxyun/J5whWSYSFFJyCxYsADDhw9HfX09xo4dizVr1ux0//nz5+OLX/wievXqhaFDh+Kyyy7Dp59+Wv2J822fuYSfvbqqjru/8sSro83/yvtfXSuh9PgC8YpiLPGqzSjfq6bslfH8r1riVT6uhjh2bUYh68H3ypAv/3mzxIvaj5xzBqwXNT/qpTOWew98+2T8L51z6syNIpNV/ldG46V1POJ+Eccjr4Nx/GzW/6L248K9fp3jscdGcF6dZ8GhkPF8LwqV9XwvnePpoDKe70XOJev5XlxMX0fY98Q2nqcCvYJgQxtZFYSLFy9GY2MjZs2ahZdeegmHHXYYJk2ahE2bNpH7P/TQQ5gxYwZmzZqF119/Hffeey8WL16MK6+8MuKZC4IgCIIgmMeWNrIaMp43bx7OO
+88TJs2DQCwcOFCLFmyBPfddx9mzJjh2/+FF17AkUceiTPOOAMAMHz4cHznO9/BH//4x+pP3tEGdOwI6yoqFEx5vQUq15DYRoylqpbN9zAMt2qZGxEgE/GJPxj5YWrqr83wex3q5BDqNMTmVC1baYZdAfpa3Qnd0fDeU+z+emXoNMPWCb9yj+d61XJgQu5zCCS7IMXk/KK41ijpii5UO6ZabGkja4Kwra0NL774ImbOnFnclslkMHHiRKxevZocM378ePz85z/HmjVrMGbMGLz11ltYunQppkyZUv0EyhpTUwUkbJHIFY6USNQoSGHnJGoIQlVWZVxLFKNQFKiG1hqVvdR3GSXguHno3JVUdHIIuTqCLzrLb4J/IDeX0fDCHdIQmyDsZtiVzpGUqmUKTq5h2G1tgGQXpJBjRSQC0CsqaWlpKdleV1eHuro63/42tZE1Qbh582bk83kMGjSoZPugQYOwbt06cswZZ5yBzZs34ytf+QqUUujo6MAFF1ywU1u0tbUVra2txX8XH0pHe4lDSIo6jxCJhNwnvxqZwjEJBSl5Yh9upXCBKa701ksOTju1DFoEYpLCv1/wimXTBcC2WttEUckcdtUyp2K5cx7pqlqm4BSkJGGdZcCtghSKNIpEnaKSoUOHlmyfNWsWrrvuOt/+UWkjilhVGa9atQpz5szBHXfcgbFjx+KNN97AJZdcguuvvx7XXHMNOWbu3LmYPXu2/390FYl0ke0g9iGefIbYjxBwJe7jzo5H9TqMYBk9aixHiNJhZaqFjW8TvW4x09Hi9iakzQH/ZHTWWuaGlim456Xg7Md2JTUarFPPVkcQ6rRWi6KS2aSYMO0G6pyDGwqPoqJY57zl53BlCT0g2cvohU35fXKiSKWruq/aMQDeffdd9OnTp7iZcgeDEkQbUVgThAMGDEA2m0Vzc3PJ9ubmZjQ0NJBjrrnmGkyZMgXnnnsuAOCQQw7Btm3bcP755+Oqq65ChvhTfubMmWhsbCz+u6WlpVOpt7cDbd0EEGUDkKFgv9BTxK+r6dCy6YbY3NBypqz5HSesDFQTWia+kIjvI67jGEVomdsQ23R7Gs7vnOkVU0znH9LPh5qLGz9KlSmftFn3zvX8Q/J4lkLLnHOYboZti7jlH7ro8ungZTx4VQrTrv379OlTIggrEZU2orBWZZzL5TBq1CisWLGiuK1QKGDFihUYN24cOWb79u2+C8tmOwULJUqAThXe9SC4D0QQBEEQBCFqotJGFFZDxo2NjZg6dSpGjx6NMWPGYP78+di2bVuxsuass87CkCFDMHfuXADA5MmTMW/ePHzpS18q2qLXXHMNJk+eXLx4Nh35zlf3f5dDhoeJOk7C0VOE8+dRS+U5nmtYPrbcMQT4ja8Tk2uosUIKN/rIXpOZtU/wXEPTBSnJXUbPXDEKEE1z7aTmGhpvfE2Q5DAyRVDXMC65gWyymc5XtWOqxJY2sioITz/9dHz44Ye49tpr0dTUhJEjR2LZsmXFZMp33nmnRPVeffXV8DwPV199Nd5//33ssccemDx5Mm644YbqT97WAbTtWKmEDBlT1jC1XwdPJJJikjqeS7mGjMQy7jjTuYaUqNGpKI4i15DfdobYr3wst9UNATvXUCNf0HSuIXkOs4ezguk8yLTnGhovIHE81zBuoitu8y1hZ534dzamSmxpI09V4ycmgJaWFvTt2xcfrWxEn97dkjrr/Q4c6omkTyoRtIYYW1vv30a4fF5tL97xqG3E8RQhQjqUX4gWlN8Rzav2Hrflld81ZR+LyL/ME2KVquxtJUQytV8HkZNIH48YS+xHjSXPS2yjxBS1H5VrSPVJLD8etQ/ZD5GYB7UfVbjDmQdAO7jcuVDHo7bx50I4uMTzJvfjbis7Hv9Yvk1ac6OgqpZ1rpXC9Jy5z4wzVucaOMevvCORg80UOpQgpNARTuy56JzDYEFK+Xw7Wrfj+fmnYsuWLZGnfnXphqYfTESfuupaxbW0tqPhp
89YmXe1xKrK2CSqox2qfYfI8EiHkLmNcgPJcDOvSIXrEJKhZUI4kmFeQoh5RNjLX1TCcwO5oWVyLBkKJv4ypsSAf1NiQ8tBw8qV99NoY8M+Bw/qd4X9h7aEllm4FFomjxfQNQxancw9fjWkPbSsIxCdrTKOwCG0RWoFoS+HkGo700HE3juI/WqYbWeIvEKV9ztp7BVSFPH4CAvCZIUyRzRWOhZ3bPnqKAAt/mqoaDtzhRTXQ8u8imLeOSmMh5YdamNDnsPs4ZwhKW1syOMFnJ/p1VFMVyinKbTsclubQHQt6F7tmJiQckG4Q7QpQtR5VKEJdxuz+ITd/JrdEzHcghRuMQp5fGIs3SCb18aGallDFqkwc+10BJzpFVKCFpDorI5CnyN+bWwoouhXGJzgrqFLDiH3HGEL0ShyGaX5dTTNr13D8wK0neEW8DmAtbYzgiAIgiAIghuk3CHsoe2MThiZtIeI/aiwb546HpHIqpF/6GUIR5STQ6ixEgoVCvcIN5D6i4r6y4UK3XJdOTLc7PhqKJxj6eQQckO3WpXHzLlQ6JyXhHJ5DOeQhY2O88ldMk/nvDYqlKM4JxvHw8g6BHUNYx1GjqjtjC3SKwjb24G2bm9MIkyrFUamik8oQcgM+6oCkWvILT4hRKdHPHpOGNkjpJlHrTbCDUlzl9DTCCOTIowUF1RRBTHWUhi5HJ2iEur4psPIZGEIN4xsWPxxfwu5q3eEDy+MbFrARlEYonW8sFcqocZKGDk+bWFCRmelkjiQXkHocwgJYdbmF2Fkfh+3X2ENcQ7TayhzHceAfQ3ZPQ0p549yIJnijxaY/vPWEF/KXHFBuYYUbNGp4RpSVdDl7W5sNL6uvF+6K5RN//DTwix4riEX066hjfzDKNZP1iJFrmHYja8jR4pKkolqz0O17xCEXg23gIQSdcwwMhVupgQhdw1lyvljij+yajlDOYllH3Aq7Eu4hnTlcbD1kyuN9ZT/vFlCnFLiiruGMvXdzRedwV1D6oe+fD9udTJXEOtUKOuIMG6FMvf3wbSY5GCrubRpkrCGciRrL5t2DS2toRynCuVClYuRhYIIwmSi2vJQ2W6CsJYp/ihRV0v8enFDy6RryM0r1HANyQ65/uOVu4bslVAoAUeFmwnXg95GiE6ma0i3nWG6d/5N5H588cc8LyOfMYrl8rhCj+0acvM0iaNFIfToamlzP3zmnURqa/iijp6Lu1XLzruGTEy7hmGTtOrkpIeM45PtKAiCIAiCIIRCah1CtBWAbg5h9/BxF9wwsmonCj7INY81wshEONdGGJl076h8ROKtxV3lRMLIvDCyTpPrpISRybkYLkgxaazEMbTMDSNrOX8GzxFJE2pqrBSfsMLIsXYNJWScTFR7Aapmxy+MRy2UqhFGVoSYZFco2yo+YYSROYUnAH9pPJ38Q9fDyNzl8djirEyw6jS5jmMYmYpn5KkCf+61ORJGNo/Z4pMoCmZMhpHDXhqPOmc152WT4OKT2OJlKuVp7HxMTBBB2PXvTwkhRfyis4tPSHFFfKhcdw3LxxKXwHUNTYs/qt0N1cPQlmvIFSbB+xVy+zC64xpSsFdX0el/6LBrGEfiVnwirqEeQYtPkiYkvawHr8ovjmr3t0lqBSE68kD7DgWg2gnBoeUaMhtdU+sgu+QalgvHgO1qAPPFJ6Zb1thyDUnRyTgvu0DDIdeQFnVmi08oQncNI3AMdYSES8UnFCYFURQCzvR52VhyDcMmNmHkjFdFD6tuY2JCagWhai9AZbv9nBDiT1GCkMo1zPGcP3aja0okUh8OsvLYsGtYfjxqPWLDrqFOo2vTriE7508jX44Wfz2fl9upgnt8064hd2wSWtZE8RsaR7fJRuVxFM6fDjqiM27ohJ/LRSIlGiMn4TmE8QluC4IgCIIgCKEgDuFneITzR1YeE6Fldh6g6fWSSScx5PWSQ14rGdBb+cSl9ZKpcGaGOG/Q9ZL5Th3luPr303H++EUl3P3i1
cPQdOGJNUfP8NJ9rqyXbDqMTM7DVhhdik8iI+l9CNMrCDvyJXmD3BxCKozMrh7WCSOz11Cm4mW80DK5XnKhbD/q+GRFcbC1kjvHMgtIuFXGCV0vWWutZG44V0PoRRFGNr1eMgXnek3/DppuT0NDvI+5SZkahB1Gdj1kTuZ4ShjZwkwCkM10vqodExNSKwjRXigRN5TQIx3CPCEkKJGYY1YZk9XIOtuYDiG1X7n4o/ajjkXmFRL3hMwh5Ak99lgyh9Ad15ByCIO6dTprJdOmQnD3Usc1ZItTnSpjYht3LEc46qyVnJQffq7Q03ENyfOWiSmd+2mrGpmLS2KXIulVxsgiQA5hKDMJhdQKQtWuoLoJQi/PcwPpQhOuaxhFD0Pi3cctPuEIO+6xCGfNI9dK5rqBTOFIOo5mj8cVYsSdMiridJpcRyH02G4gO7RM/bgQxTGGXUPO979OUYnpMK0e1MW6EzI1eV6pRnanGjkurqHnBQgZE9/TrpJiQZgvqVpS7UToto0ZMiY+GOyVT9jOn04PQ56Dx6lG9pjjyFVPKKeOqgA2vPKJaeFIVSPTQoz4ktPoYchZy9i0a0hWIzNFk44baNo11AkPc36XdHIUI4jSOkXYK5+YzvlzvYehy8RF6LGRKmNBEARBEAQhyaTXIexQUN3+eiPDw8RfMlrVyHWEF8ANBXObWlPuIhXmpY7HqUamKpFJ15DYRoSRM1Rxi+Gl8MyHjP3PJ4qm1n5jQSNXSiOMbLpIxVZTa5NL4WmFjGO4vrFOmDvs0LLJnoa6x4vivK5XIwelfB4FF0Kv0pg6meQ7POS7PagMU+ixq5EpMWmrGrmGuW4xIz+QVYkMsKuRo2hqTRepxL+pddCG1gBfwCS6qbXBcLNLy+CZr1Am3lOW4tyurGhCHi8BIV5h58jSdQmlkPdKPqxkbiBVQMLNF2wjxCS1NrLr1cjlIs7ksSpsc6kaGYp4thFUI3cwegcG7V9YzViqhyGFS9XINnoY6og6bg/DpIgLHcFqMq8wjtXIcVytppxY5xVmMhUq1HoYExPSKwjLHMJCByEQTLeioX41uNXIBappNLf4hFmNTAk2X8g4uNDjhpHdqkbmjQV4YWST1cg6hSFJqUbmO528sdR3NxVGLsdkC5tKSDUyD+POXwzDyPyThBtGtrEucqhIyDiZqIIH1e3DRWmVrEYrGmphbnq/4BXFxtdG5gg2QugpykXTanXj32Q8jEyFgnXGMsPIHdRqLQHDyFF8z/BdTiKMbLiylytOKXTCwyYbU0chHJMMx110SfxFEUZOQqg6Nq6hOITJpFAoDRlTHyCu0GOHkZPQ1Jrb0JqdQ2gpjJyIpfCYYVCdMDIzZByFu2ir+ISTLsedBxcJI9OYbGqdlDAyFxtOZ9iFJ5GTcEEYn5kKgiAIgiAIoZBahzDfAeS7uTp5IodQ8grBKyrhOn9OrXLCzRc0m6eos8qJfxzxF7+tvEKNXD7TDaejWBtZ8gopJK8wbOKYV5govAA5hC60y2GSWkFY6PBK+hop4kMleYXgFZUwVjgBKqxywt7m36SVV6hRVMKtPKbmzBdxxJetKt+HGked07+NIzg7x5oVnUlZ5UTyCt1B8gqTSbngVC4UZyQ8ZJxaQagKpbqD+kBKXiH84sxwRbFOyxqdvEKqKpjtEBrOK6RgVRmTI4Pn/FHn5OYQcs9h2l2UvMLkigvJK0xme5pY5xWKIEwm+Xxp2xkqZEy5hqQe0gkjUw2sXQ4jc8UfVzhSWGpPQ2E6jGy6PY1/nISRbYSROSHkSvOgcD2MzBc11FYJI8eOgGHkxIWQpe1MMil3COmUNyLkRWzLUGFkKjxMLoXHXOXElTAy1/njtqfROJ5L7Wm4YWST7Wm4K5zEsT1NFGFkLhwxyRV6OmFkl35bue6dK4LIdfGn40xy5yKhagNkvAAOYXzucXy8TEEQBEEQBCEUUusQFvKlRSVURTFpShEuEjvs20a4ZnXEWGrZuzrC5eO6h
ibzCnUKQ3RCy5b6FZLOH9OFDHvZO+6SdxS2+hWy52I4r5B73qDL3nGX9tWZL9dosLTMsDWyjBC5K04lkPDQctKRHMJkogpeyYeQEnpkoQklHJlVxtxQML2fjtDjjmXsZ1jAxXHZO7oghSk6DS57F3TJO8BiXqFhARfFsnccgramAfjXxZ+L2dhyJO1UNPIeQ881dCjU6lQDa0ZeYdAl7wBHC00khzCZ+BxCUoMQP4akfmFaHMyWNdQHgdyPcBe5IlFRVbEckajjNnJzA3XGMtHKK9QQieRciONx8gq5S965BC1YibxCDfGnk6enU2gS9PjcsRRR5BWaFphUJTz1kbeBS7mGUeQVClUiDmEy8bWdYTamNt2vkPwrSKeNTS5kJ5HaR8P5IwtNdCqUmYUmFOxl7zTcQG6FMnUd5WKK/sOz52KUStuMh5FJQWQ2tBzFeskcIWp6HlxMt7ZJKlEUwYgISwEiCJNJvsNDvtuPJztyGUW/QmKbRwhMo6Fg7jZLLWZIDOcVcuG2ouGO5YaMOZBtZyrMxAZc4Ri3foVBexVWMw8uaetXWA6VU2iyV2EYWFv5xKFweBzwPI9ehKCHMXEhPtJVEARBEARBCIXUOoSFvEKh21/+ZGNqsqLYfyx2A2vKDSR7ExJVxgWiypgMN4dcaBJ0hZMwtsWwgTV3LMc1dGlJuoA551XhcgNrk2sgVzMPHVwu5KhE2MvUsefhkLPmdKGJxhrI5YUmVOFJ5HgBQsZUBwpHSa8gLJTqCbLK2FahCbONjelqZFahCXl8DWHKDQ873sCagl7izlyhCad5NcBvYG0LdqFJApo6S6FJJaiLc0N0uVRoooPkOBpAcgiTSXuHQns3AVRHCT1ub0KdQhNuD0Nm/qH5HEKGQ0jhUqGJBjqFJqQQZRep9Hw87uogOusbS6FJ8JVKpNAkfrgumlxyJlOJtJ1JJoVCachYFfw/1GSU0nQDa6L4RKfQhNzPpEg0HTKmkEITpKnQhCKK9ZJNFppwmlcDUmhiC07zasBiqNXwWB1EdO4EcQiTicqXfjl3tPs/BDV1vLxCssqY0iXEOdh5ha40sOYen0InX5B7PIfyCilMikSuaHI9r9DG2sPV7GeSOOYVuuSQcRAh5VBeITORmJtXaB1xCJNJuUPINbm4eYVkDiE5EV5eIfnNz91mMK9Qy22kIGPrxH46DawNw80NDL+BNe8LM44NrClsrEBS8by+8H3weZgONxsXiRGsfMIVMJxCmLgJWCFGZLwADmF83o/x8TIFQRAEQRCEUEitQ9jRrtDRzdnKU2Xv7LYz/m2UeUW1tskwC0244WFuuDlwTM50M2xmpTBZaELNj8Jw5TGIuVBQDUlJ15C6NqqZqa+QgRfi1Skq0QlLm15XmcrTo6DOoVNAwtnPlZB0pfO6FH1ju4HER0WWuCPGyhJ30ZHwHML4zNQw+ULpq5D3v7qWtyt9ecFfingVlO+FPPNFXph/P0W8SLp68ezsxYUaS91QCmo/7vE0Xh4yxIv6j9jP878oyONxx37WJb/abvkZT/leWeKlQ1dqTfdXFGQ9/4uLrTlz0LkuCpev1XUyGeV7pY1MVvleqYT6IHFeAViwYAGGDx+O+vp6jB07FmvWrNnp/h999BEuuugi7Lnnnqirq8MXvvAFLF26tLrLCzRTg9i4aOCzHMJ8t1fB/6KFXXDNwZ+c8r+4Qq9c6VL5iJ03gHgR5+WIRK5wNH3zjN/44HCFI/94PY81/V0UhXDkQgkiW6KGc17ufXJJwJoWnUkgCvEnAjPGdDmE1b6qZPHixWhsbMSsWbPw0ksv4bDDDsOkSZOwadMmcv+2tjYce+yx2LhxIx599FGsX78ed999N4YMGVLVea2GjLsueuHChRg7dizmz5+PSZMmYf369Rg4cKBv/66LHjhwIB599FEMGTIEb7/9Nvr161f1uQuF0irjAhEFZBeQUKFlZriZ22LGmcpj00UlOhXFXCKoPOZitj0NL3StU
3lMwe1/mOQVTUziUuUxV5wkIawYxZrHpiuPXW9jU36OJLxPSohopZJ58+bhvPPOw7Rp0wAACxcuxJIlS3DfffdhxowZvv3vu+8+/POf/8QLL7yA2tpaAMDw4cOrPq9VQWjrooEuQdi9ypj64FLCjD4WZz/uEnfkR8hW5XEZrNVMNI5fcb+EVB6zj8cQibTjKJXHUTSE5uQQSuWxHUHg8twEBG5FQ7WmiRyNtjMtLS0lm+vq6lBXV+fbva2tDS+++CJmzpy54xCZDCZOnIjVq1eTp3jiiScwbtw4XHTRRfjv//5v7LHHHjjjjDNwxRVXIJulfjwrTJW9p2G6LnrixIk7JlPFRQ8aNAgjRozAnDlzkN/Jn9Otra1oaWkpeQmCIAiCIETF0KFD0bdv3+Jr7ty55H6bN29GPp/HoEGDSrYPGjQITU1N5Ji33noLjz76KPL5PJYuXYprrrkGP/3pT/Ef//EfVc3RmkO4s4tet24dOeatt97CypUrceaZZ2Lp0qV444038L3vfQ/t7e2YNWsWOWbu3LmYPXu2b7vKl/5Rwl6Ag+n6GF/zmIAMD2vsR2KyrI/rGjLdwNRXHpMhT/9Gas1jLtzKY9Nww9LcymNXSFvlsY01j5PSDFrWPHYQjSrjd999F3369CluptzBoBQKBQwcOBB33XUXstksRo0ahffffx+33HJLRW1EEau2M0EueubMmWhsbCz+u6WlBUOHDu0sHOkeMia+lclWNMywr1NrHlMEXqkkghxC06uXOLTmMV849jyW05oG4IsrnaXrjK95bHh9Y3oNZY3j+VoABT8+F9fbyaQJ10WTS6umJAovU31O4Gf79+nTp0QQVmLAgAHIZrNobm4u2d7c3IyGhgZyzJ577ona2tqS8PCBBx6IpqYmtLW1IZfLsaZqTRBGddGV4vTlUCl6VKEJBS0SmcKRu+Yx89eFXvOYV7hitl8h03Kl0FnzmJsuYWnNYwqTRSU67h3pLsbMgdPF9BrCHHTWPNbBxrUKQqzREIRccrkcRo0ahRUrVuDkk08G0GmGrVixAtOnTyfHHHnkkXjooYdQKBSQ+cyR/Nvf/oY999yTLQYBizmE3S+6i66LHjduHDnmyCOPxBtvvIFCNyEQ5KIBqg+h8r8Mt6LRgtmKhnWx+QotZljz0Gg7Q6HTYoa7XwStaOi2M8Fb0VAjgxLHVjTUnE23ojHZAoY+vp37pIO0ovEjrWhoUtmbMKK2M42Njbj77ruxaNEivP7667jwwguxbdu2YgHuWWedVVJ0cuGFF+Kf//wnLrnkEvztb3/DkiVLMGfOHFx00UVVnddqyLixsRFTp07F6NGjMWbMGMyfP9930UOGDCkmX1544YW4/fbbcckll+Diiy/G3//+d8yZMwff//73qz53Ia9QICpmS/chtjFb0VDotKIhj8ddlYQLx9UzvVIJRQxb0XBDwRTcNY9NjQPcb0WTdqfK9VY0En50B5fD1y7PLRCeF8AhrP56Tz/9dHz44Ye49tpr0dTUhJEjR2LZsmXFmot33nmn6AQCnQUrTz31FC677DIceuihGDJkCC655BJcccUVVZ3XqiC0ddGCIAiCIAiuMn369Ioh4lWrVvm2jRs3Dn/4wx+0zmm9qMTGRQM7FuEo/pubtsZ179jrIDP3o3L+qBNzm1q7YsHY6k3Ib81kFJ3Qr43ehMyhTkE6ZCH39eO6cknJ2+OGB112g7hhWdebVevgjINnujt9WESQQ2gT64LQFoW8QqFbmItsTM1u/0IcP4qs8ErL0vn24+YHBqsy1mpWTc5DQ/xxYVc3+zdxW9HohJZNVhnrtKKJz1eZGVyu5HW9WTWFK4LDlXkIMUcEYTrg5wv6t6WuN6FJuL+2TPGn1ZvQElxhZxIq5497p7i9Cbl5iuxWNMzzRtGbsHx+rojGOBD2SiIuOXBCwtDoQxgHUisICwX6B6YnqN6E1A+ptd6ExOyM9iYMugZypf0oTFcBW+pNSGG0NyGne
TXAFpc6IpGLSyHTsF0zl9dPrnRel4Qtv9K29EIiic4wcalZtY5QFof1M8QhTCaFAiNkzK72NTYtfbi/JK4kLul8e1sSjtzehDpVwCZ7HUpvQj3x45KILce0qHP5WgXBOgkXhPGZqSAIgiAIghAKqXUIOSR69RLufoH7EBqutNFZvYToL+gSNvoQ6hBFaFkHGy6X6XPaWr1ECI6EVVNAwh3C1ArCfL60FQW3YJeCK/QigR0ydvjnJYrG1LZyCB0RTrorWpSjU8msd17/Npfy4Mqh7pPpIhgJ+4ZPFOJPBKaLBBCEMQrEplYQcqByCLmtaCi4biAJ9Y2uUyxC7mfhV8OlDHDDq5d4hn/og/chTMbqJVG0KrNV4OEy4lb6EWFGU164krjqbqkyTiaqvKhE41vftKZxqkiF0YdQazk7LlG4hg5hKxycJsRJC07al7MT9y44cb53npeB51W3qgHZWcJRUisIg8LNIaTe4Gw3kAmV8xe7PoQUiVbYfoKGkbnjZPUS85S7ZtE0efZv465vLAiCASSHMJkU8qVhKZ0aCB13UWc5O7aMcLkVjY5rqLMknaXl7HT6ELpCfL7edmAjFOxS+DmKuZB97mLi/OyMLHFdppezM41Wv0JXHLzyv0ZNJz4HIeGCMD4zFQRBEARBEEIhtQ6hwISzUgn7WG6HbqNY31jHDeSsVMLFVlVwHAm7kjnJuYzOuE1C/PD99jjwoUi4QyiCsEq4Vcb22s4wT2zyF4cQekoZzlHkhngdzxcMGxvrIgNuRHO6oKqbo1jfWBCCIusvE7gYMpYq42TiW7rOcIsZ+pzUNo03uQt/McUARa35a2Ee1WBL2JnEhe9voTIu5Ti6gjiawk4Rh1DoCdORUOp4VKEJl9hVHnOFbsrdQNNEsQKJTq9DsqkzKWqif9faapAdR1EnokuILSIIhe64ngYnxI+guYbcca67jZSTaCvsG1RguS7CBEEwQMIFYXxmKgiCIAiCIISCOIQhYboJtTXCtkRtWa4SbhYMErc1lQVBCIAUlQhxwpk8QAr5hYzdknRR5BUKgg0kl1GoGs8LEDKOz3tKBOFO4C5Tx8UpUyqoOGOvW+zSxTJx6gEJSYSboxjHYhFBSDySQ1jKs88+W/H//dd//ZfWZARBEIT0kcko38vo8bPK9xKEqukShNW+YkLVMz3uuOPwox/9CO3t7cVtmzdvxuTJkzFjxgyjk4uSQsH/EhxCFfwvnf0EQRAEoRpEEJby7LPP4rHHHsPhhx+O1157DUuWLMGIESPQ0tKCtWvXhjBFQRAEQRAEIUyqziEcP3481q5diwsuuABf/vKXUSgUcP311+Pyyy+n+50J4cFdpk4IDnd9Y+KtT61vLAhC9FDLwEkBiVA1HqqvqYvR2yyQl/m3v/0Nf/7zn7HXXnuhpqYG69evx/bt203PLVK6qsm7v5wnm/G/BEEQBEEwjlIq0CsuVK0gbrzxRowbNw7HHnssXn31VaxZswYvv/wyDj30UKxevTqMOSaGGKcWCN2RBykEJK/8L6HTrSt/hU3YhSxC8lAoBHrFhapDxrfeeisef/xxHH/88QCAESNGYM2aNbjyyitx9NFHo7W11fgkbZHJ2p5BiFDrhSUBh8SZcml9OMEJpJ2MO0jIWKgW9dl/1Y6JC1ULwldeeQUDBgwo2VZbW4tbbrkFX//6141NLO54ZEPfCM4bttAjF54lRBi5TWNutoSeQwJTiD9c8SciURDcQ6kCVJWdK6rd3yZVC8JyMdidCRMmaE0mDsQit1AIhog/H4WkLMEoOIM4c36oohfBPZLuEMovoCAIgiAIQspJ7dJ1mYyHTEra5GiFkcO2RFNmucYpwTiNSKjWLHFzA+M2XyFaOquGqw0Zx+dLJbWCsJwMle3NHauhaajKNup4XgQVcKHnH7qEhIcFg9gSkjrnFfGbXBIhbAtq5/+2QNJDxiIId0LGsEAyboYlVcBxryshoi5o0jF3nOt/oVLf81TuYl7yG
QUL5CW/T/iMIG1k4hQVEkFYJVyR6LxW0XBE44bnJbl/kDmiKCAx7UpJ0Ys7SGFEMnDGXSz/rXXAAJEq44SSyZa+v7idU+hjUW9U3i8fJRy1wsOmVyvxfSh5N4pcxlAvts7bz3klzsN1V68c2uWLfh6d5zX7wxF2aNWBSJhQJVGIJmeEmWHY1yUh48hJxq+nIAiCIAiCEJjUOoRB4a5eQhWL6DSr1jK+XA4POxAG2Cmkhevf5Iqj58o8BLeKNoyH6hPqXgnCzpCQcULxytrO2Koy1sEj5sytFCb3c1k4UpgOQRsON1NfBDpfDia/WNJWoOGSOAtKIT6/K4KQSJIeMk6tIORAiUSdymPSbCJcQ1KXuCLWTC9Jxz2HDo7nFcbpCwNAjGrmdo5OSlISBCaXNF2rDuKaJh+pMk4o2WypxjJdixGJBuEKMXZ1jMFCEFtrGTsu/kzCFZL0FxLvWTiQxy1UiTwzwRZJrzTvbExdpUMYozSe1ApCDpRIpHIIuVXB1H5c15AN10k0rYCD4vpKJcQDoj7gpvNEbPxV6VJzZVtVy0HvAXdu3FC9rWfhkpikHLe4uXCm56tzvKSLtShIukPo+K+xIAiCIAiCEDapdQjL1zKmcgO5hSbGewmS5wheBEIVn7Apd/C4YWUd5y+K8DA11rBbSYV0dfIFOWNd+muU6gcYtx6BrpOU69dxr8qLbcQJE8JCikoSSiYTLKUta7qohLmWMTvEyw4ZExXKYReu6AhHrviLoHrYNDZEnE5o0PVl5UyHPV0Ko5okCjHpStjXlXnEFblXnUjbmZRA5QaytQq3Utgw3BYzWsUcAVcq0cJWAQmt2H2bFPJmz0tg0g3Uq6ZNxg+BDSeNe87EuHwJEA1RrFssDmZ8Uaje8YvTxzu1gjCT7TlkTI7TMKqoZtW0oUW8hXT6C3L3Mynswj5+NTCFng50eDg+fxkC6VsXWEeIuSwwTY+lEFHjDkkQ4rEhgEMIcQjdJ+OV6hNawPHyCvlVxtS2CNYtNppDyFy3mH185limgPM85lIylqCbVTPbxzC+WFxqccCvHuaFoKMQrLJusZ+0C44kX3/Y1xZ43WJHSXoOodvJVSGSyXolL3of4pWhXsr38jz/i4uXIV5Zz/ciyXr+F3e/DPEKCnmjmMenxlJQN4oLeZOpF3wvRf332V+O3V9h09UTq/uLS0F5vhd/LO8VBTrXYZK88r90KBT8L1tzEfwUCp7v5dLxhHDoajtT7SsICxYswPDhw1FfX4+xY8dizZo1rHEPP/wwPM/DySefXPU5nRCENi5cEARBEATBNRYvXozGxkbMmjULL730Eg477DBMmjQJmzZt2um4jRs34oc//CGOOuqoQOe1LghtXXg2U/oqdwwzWa+zNU3Zi4JtNhGuIWlyMd07j3hR7hq5HwVtf/bs8nEdPe45dfajbqjO/DSgXEPKXeSODUpeeb4XlwLxch3KIdNxMF123Ew7sy5fq8CnkPd8rziewzWoqAznVS3z5s3Deeedh2nTpuGggw7CwoULscsuu+C+++6rOCafz+PMM8/E7Nmzse+++wa6PuuC0NaFl1MuEDtFov/lZZTGiysSmQKOHR6mLs5geJhCRzhyQ8E6IWPm8egPuNnwMDfEwBKSGuGKKMK+LgkOnblw7hNXiNu6J7bC/ORciJAp/SLC6eXCxFL4VcLIySeKkHFbWxtefPFFTJw4sbgtk8lg4sSJWL16dcVxP/7xjzFw4EB897vfDXx9VotKui585syZxW3VXvjvf//7QOcud/yoPEKq56BeL2SdAhJeY2oyt5C7H2f9YW7bGdP9BXX2kybUPvQqVnkFH9x8Pu4yddwilaTiUpUxBbfXXxqcpDgjwrMyOmsZt7S0lGyvq6tDXV2db//Nmzcjn89j0KBBJdsHDRqEdevWked47rnncO+992Lt2rVVza0cq4IwigtvbW1Fa2tr8d9dD8XL9lxlTMHtOajTR5mqHrbScxBgVRnzj8+8U
Vo3z210RBxnrOk1gOPYikZH/AS9Vy71HLQl/lyGmm8cew7G7b4nDZ21jIcOHVqyfdasWbjuuuu057R161ZMmTIFd999NwYMGKB1rFi1nQly4XPnzsXs2bN92zsjljvvQ8jVL5H0HOSKRLItjiM9B3VazJh2CDWaUHN7DnJDyTbaEri+2ohpXHfXXDhnJdLk6EUhuGyJOmfEJPOvrkzZfsqBcECQ1Iqu/d9991306dOnuJ1yBwFgwIAByGazaG5uLtne3NyMhoYG3/5vvvkmNm7ciMmTJ+8452dtCWpqarB+/Xrst99+rLlaFYRRXPjMmTPR2NhY/HdLS4tPqQuCIAiCIIRFnz59SgRhJXK5HEaNGoUVK1YUO6gUCgWsWLEC06dP9+1/wAEH4JVXXinZdvXVV2Pr1q249dZbq9I7VgVhFBdeKU5fXtNALl1HuG2ZjN/14YeRmftprEfM3o/t6pVuI5tQmw4jUxA3yngTaup2Gv6j1EYTatNhX538Pp0m1KavI2zXkOskVNNjMMg8qhnrjIvkEHJP+CT9XlXbpaFrTLU0NjZi6tSpGD16NMaMGYP58+dj27ZtmDZtGgDgrLPOwpAhQzB37lzU19djxIgRJeP79esHAL7tPWE9ZGzrwjOZ0obUOiFjcht3fWOmWOOGffm5htzYN6OohDy+hnDUKUgxnKdIhofJ1UY0cgO5oWWGANQJ+1KzcCBKs1NMF5q4FKotx/VnkSZMCx+d40UR0k9T2sDO0AkZV8Ppp5+ODz/8ENdeey2ampowcuRILFu2rFhv8c477yATQg69dUFo68Jraz3UdheElENYQ4k67jb/8WiRyMsD5FcUa+TuBS1I0RF/pquMuVD5goTg0hJ6GusbB68yDl7Zm2TCLjThCnFbgjOK582tKGZXIxNvb+7xBB5y76ojyGpIQSMb06dPJyOlALBq1aqdjn3ggQcCndO6IATsXHh5yJhsMaNRUUyJRG5FMYlGeJjdYoaCU2XM3aZT8KFzPEstZkyKP+7xuOFcLjrhXH4YmTovbz8dwhZJtgpZTItOroBzBakoFsKiEKBPaJz++HZCENrA14fQoYpiyjWkQ8YaFcUBcwiddwMtVRRzCVv8ceEKPdcxLX5stI9xSvylKDQoFcURELCi2FWidAhtEL9GboIgCIIgCIJRUusQetnSvEGnKoqpMDIzr9BkRTFAVBWbLiphundkRbHpvEIN6EITcxXF3OOZrs6NY0UxNb+wc/ekojgZyD3hu8RpvFdRFZXYIrWCsDxknCULSPzjnK8o1hB/gcPBcawoJqbMLSqRimI72KootlEIore6jLl56GK6gMQGaasoFioTVdsZW6RWENbWeKit2fGg+C5fBBXFOUL8MZ3EwGsUV9wvYA6h4RVDSDREIlv8sXP+bBSVBC/kMA1XNLk0Fwru/Dhf8lEIzrStUexKgUschVkaHT3TFED/Md3TmLiQWkFYXmVMCbgscXcoJ5G7DVl/2FOngIR2HA2KP2qs6YpiCu7xLOFyRTGXKMK5fHEVfCyXsJ3EOK5RnNQWLlRFsesFJLbGGofxQYtLAQlFAQGKSpitwFwgvYIwW9qYOlsTQb5gLZUbGDxf0HTImFyFhHUs5hvedL5gDNcopl1IXv4h53vU1hrFUVTSmc4NdHmN4ihax6QduSd8bNwrr+xDUP5vGyQ9h9Ad20UQBEEQBEGwQmodwmyNKgnrxrKAhB0ytpBDqBNapnC8gERnlRNuNXK502e6eth0g2jTTa256DiJJnMhuRXFph0Ea6uhSAGJ0ePJknTuIUUlCcXLlGoHbgEJuZydrQISKgRtI4fQdAGJYZEYtyXpKh+v5y8WUuixzug+pldhSUIBifFcyxgWkIS9CkkcRZOEw8NBBQgZM//ed4L0CkJPlQg5bgEJpUHIApJa/wFjV0BC7efSknRsMalTKexGAQng/yLiu1lmC0iiWJIuCqFH4UorGikgCY4UkEREnJLjDJH0lUpSK
wgz5SFjxxtO026g4QISC21ntApImJBFG1JAEhhb57VVQBI0PGzNNXRJNDiC6/fEpfCwyXvFrSh2oWCEQ15V/3mNyaUBSLEg9LxSjUGFgjNZ3jZ2viAzxOvVEiIpiobTHLdSZ9Fnihg2nLaRLwjw/tLkNpw2nS9IziUh+YImcb3hNDc8bDpf0EZ/QdOh8NSvW5wCxCFMKJkaVSICuW4gqVUIAUe3mGGKRGq/SELGxLaasrdIFPmClhpO64SM9cby8gXLBYz5Xn1mw8N6c/FvcyVfkDsXl1riuNLQWRcbPQYlPCx0IW1nBEEQBEEQhESTWocwm1XIdgv/clcboULLdBFI8NYx3FY05HnLHb1qxnL2c7yARKfhNAXXDeSOpQjquHHDykHDz7roFJDotH8xHTJ22Q20VQHsMuKshZAvyLS54rwKCQfJIUwovrYzGv0FqbAvKeqo0DK5HxW6JfIKSfFHhZupnESDVcZMoUcWkFjKF6Sg8wqD5x/SY3nfDjpLxvnH+bfxQ8FmK5R1iEJ0cX/POD0GTYe9k9Jf0MrawwnpL5gEscstICkXlwUHxKbkECaU8hxCyvnLEgUkZMGHTn/BHCESo8gXpAQmR5wlJV+QnVcYfosZbt5a+fcht4DEJUy3nTHdOJt7jiD7VDNWq/iELeDccRe5/QVNCiLT1yr5gskn6TmEqRWE2ZrSkDAVHqaKSujq4eDOn1Z42HSVMSUSy4VYhnjLOB4eprBVLKJT8esfJ/0FudgII5scV4m0h4cBwyLRtJMYx3YyEh6uSCHASiXiEMYAD6WNqUkNUut4eJjYz2h4mNqmEx6mun9H0U6G2V+QG8413WKGL9gYbWdiGB423TrGRniYIo79BWX5ufCP5/p5TRI0POwqSXcIpcpYEARBEAQh5aTWISzPIZR8QfDcOsfzBSl08gX5OYm84wXNFwT8OYOu/+WZ1HxBaj/+PPzbkpwvqLMecdj5gtJfEO5/iTiGFJUklPKl67jtZNhNqHXyBdkVxSHnCwL+nEGuqOPmGhrOFyxotJ1xOV+wc2zps5V8QZok5Au6VPDhEk6JqTJimS/IJGlL0AVF2s4klExWlS5Dx20no7P8HDdfkCn0yHxB8ng6BSk9O4QutZPRawnjTjsZyRc0O9aVfEHTwplLUpefo84Rx7YuLgtd08QlX5CiM4ewWocwpMmEQHoFYaZUs1hbfo50A7kiMeDycwDPDaT2Czqu0twcWn6OwnQ7GbZgI87h8hdLmsLD3LnYWn7OdXEhy89ZwnD1cNhuYPk8Msw/usMk6UUl6RWE5WsZk24gTySynUQNocd3A5lh5IAVv162NtC4arbprDbCrjK2FB6mCLriSJLDw7bCvkH3k/CwecIWu643l7YVHjZJnN1AiqSHjKXKWBAEQRAEIeWk1iH0ajKlLh7l6FHVw2TlMdch5PUS1HL+qLFcN5BTHGK8CbV/k1Z+H7cC2HC+oE7+HTc8XL7NVr6g6Zw/0+6d6XzBoCZHFPmCXHfRVr4gFR42Tfn8JDwshEUeARzCUGYSDqkVhKjNljSP1lpthFloYrx1DLfymN0ChnE87rEoEUocv6CI8LBWvqCtkHHwUC0FR7BFER7mohNu5h6PIop8Qc5+WvfOoQKNKLCRC2krPGz6HLZWG4k6X9BVVIAcQgdSH9mkVhB6tV6JkKNbx1A5hBrVwzqrjXCrh7nuItM19FUQM4Welhuoky/IPR57RRP/WK4Lp+PMceDOQweu0IvC+dPJZ6ScLxttZ+K42ohLQtSkOLP1LJJMXIRdUJKeQ5heQVgWMmYXhpBVxkwBpyESQ68eBnhiL4LqYT3X0E5zadPVwyYdvCj6FfLnwjseLTDdqB6mxia5uTSFTnNpl5eWi6KQQ9zAGCmkMkQQJhSvNlsq5DQqimmnjhfO9TxmGJnrELLzBf37kf0EDa5UolM9TIpEh6qHddzAoILNperhKI5naz/OWK4baE/AWTktG5P3SnoJRkOch
V1QRBAmlVympGiEFHrUUnPcYhFuOFfHIeSKv4AtZsixzJAxuYoI1/ljOnrc49H78Y6nE+I17QaW7yfFIhX2C7lYRHes71gabqCOA2e6WMTlEKy1QhNxAwOP9crGlv/bBknvQyhtZwRBEARBEFJOah1CrzZgDqFGbiC7WIS7TaPAg2wwHXSlEll72Eq7F9PzpbB1PG6eZhRuJWesjqOXZFzOIYzjEne2SGN4mEJCxgnFqynLIdQJD0dRQKITHtZZWq78eOSx/Jt0ev9xQ7wu9RKkx/o2GRWYptcedkVcAdGEh22sPxxFeDiKXoLGRZ3Bc+jkI0p4WI+gwtGFcDCHQgBBGJNLA5BmQViXhVfXvQ+hJTeQW0AShRsY0CGkBBy5TaeoRMtxDL+XIKVLwq7udb0IJG69BKsZW464gW67gXE7p03EDaxM0nMIUysIO4tKehKEEbiBOiuLmHYDOeHghLiB7dSPl4Yb2EEez7fJaLiZvRKK4VCrreNxewlG4QaG7V6JG1jF8WwUmogbmHg3kEJCxgnFt3SduIGs80bhBnLzBdPkBlLbxA204wYCfkGQ5JVFKFx2A3UEZ5KfGUXYbmCcxR9F0gWhVBkLgiAIgiCknPQ6hLXlDqGEhzkFI1GEh/n9CtMTHq60X9B5SHhYwsNJCQ+H3exbwsN28gozZdeV8f8sRE5eeewoR/cxcSG1ghA12VKB5nrrGGplEdPhYUbBiOvFInlSiAYP03ZE0Pw5aDWyzvElPMyHI2qSHGqMYhk9kyFop8LDhsWfDtxzmGwmzaVc/LmKFJUkldpaINdNUOUIcUVtY7uBwXL0Ku7HdfSosQZXF3G9WCRurWMqn7fnbeIGihto2g2MIg8wimKOoOeMoj2NDlw3UCqFwyHpOYTpFYQ1NaWijRB1kbiBTLFGuoHc5eeYoWVOM2nKqeOuMyxuoFmBqSXCNPaj0GkkTR8v+H5hu4GVtiWVKMK0VqqMJRQsbmCVSB/CpFKTKQsZM50/rmvIDd0GrPatuB/TXeSuDcxyCB1fZ5jK+WuPmRtInddWI2mtlUpi6AZyBKG4gfFzA5NCFG5gGsUfRaFAf157GhMXpMpYEARBEAQh5aTXIcyV5hA6FR7OEC6kRvEJVWVcUB3Etp5Dtdw+hGzXkOwRSJ3Dt4ncRlUUc105l8PD1H5RhIeTsM4woJfjF/j4ltxAHaJYk1nCw36iyA2MW15h+XxdmL/kECYUL5stFYFaBSQ5/zZblcLEfiZFHDe/z7RIpEQI3U7Gt4lsJB238DB1XtMiTCckS+5neJ1h46FltoiLvtWJDjrXZfq87LEWmkm7Lv50sJEbCAQPEbsg9jjkEUAQhjKTcEitIES256ISUiTq5PfpCD2DlcJApXy+nl1Dnb6B3HxBSujRvf/MuoHc87KFqHFnztyxXHIDTbuGxluiBGw7k+S+gVr5l4bb05g8vku4nBsIJC8/kEOUbWcWLFiAW265BU1NTTjssMNw2223YcyYMeS+d999N372s5/h1VdfBQCMGjUKc+bMqbh/JdIrCHM1na8ubFUKGw4FK64zF7BaOI6VwnoCk3devebXvHMELSrREX9RtI6h0Cpc0SrwiL+Y0HE54xYKBngC0HU3MM2h4DgRVch48eLFaGxsxMKFCzF27FjMnz8fkyZNwvr16zFw4EDf/qtWrcJ3vvMdjB8/HvX19bjpppvwta99DX/9618xZMgQ9nmdEIQ2lHBnY+pul0+5gaR7xwwPk8JRYyzlBjLdOqMhY+6xpFKY7SQGdc2S0jdQby7mKoWB4GLSpUphCQW7Lf50kFCwXaIShPPmzcN5552HadOmAQAWLlyIJUuW4L777sOMGTN8+//iF78o+fc999yDX//611ixYgXOOuss9nmtVxl3KeFZs2bhpZdewmGHHYZJkyZh06ZN5P5dSvjZZ5/F6tWrMXToUHzta1/D+++/H/HMBUEQBEEQzNHW1oYXX3wREydOLG7LZDKYOHEiV
q9ezTrG9u3b0d7ejv79+1d1busOoS0l7FuphHLlqGIRdni4jjdWJzeQGfblVhQH7UPIP5ZvEztHLymVwm2UkxTweK7nBlK4XClc8RyM/MCk5Aayj2chN1D3HByicANdXkIO0MsNNOkIulhlrNOYuqWlpWR7XV0d6ur8OmHz5s3I5/MYNGhQyfZBgwZh3bp1rHNeccUVGDx4cImo5GBVEHYp4ZkzZxa3mVbCra2taG1tLf67+FB8axlriL8apviLoE1M2CuJ6OQG0kLPt8l4biC1XzshzPgFJOHmBnK38cfZyQ2USmGzxG1NYe2xJlvFWBITSS4MCVP8uYpOUcnQoUNLts+aNQvXXXedmYl148Ybb8TDDz+MVatWob6+vqqxVgVhFEp47ty5mD17tv9/1ORKRSC3TQwl/iy1iWFvYzuJPW/jjtOq4k1ZbmBQ15AStVHkBlJEkRuoIxDCXknE9dxA0+cwnRtoQ/zFsTBEqoLtopND+O6776JPnz7F7ZQ7CAADBgxANptFc3Nzyfbm5mY0NDTs9Fw/+clPcOONN+KZZ57BoYceWt1E4UDIWAeOEp45cyYaGxuL/25paelU6plsqRijHEJuf8EI2sSEHQrmbotjm5i4hYK5+9kKBeuIXwrTAo59DoOtYmwtIRfF2stRhIK1xjJEUlIKQ2yRRjeQolDwqn5Pd+3fp0+fEkFYiVwuh1GjRmHFihU4+eSTPztGAStWrMD06dMrjrv55ptxww034KmnnsLo0aOrmmMXVgVhFEq4Upye4xC6tGKIjvgzuWqITpsYUsBF0CaGH0Y2GwqmRKJJcZaUULDpNjFxCwVTWMuX1HFho8gDdET8SR5g9MdygULeq/r9G+T93tjYiKlTp2L06NEYM2YM5s+fj23bthVrLc466ywMGTIEc+fOBQDcdNNNuPbaa/HQQw9h+PDhaGpqAgD07t0bvXv3Zp/XqiC0qYSRqSkRbWQomMwrZDqJhHDMG84DJPfTGcsQiVwhRQo43xa3QsF0eJg3F9MiiXPeWIo/R9rEdG4jzhvweKZDwZIH6A4i/uwcr/ye6NwjU+g4hNVw+umn48MPP8S1116LpqYmjBw5EsuWLSum173zzjvIdDOY7rzzTrS1teGUU04pOU61eYrWQ8a2lLAgCIIgCIKLTJ8+vaIxtmrVqpJ/b9y40cg5rQtCW0rYy+ZKXUFLhSF51c4ba7AwpJpt5e6fTtNondCt6bFhu3eVjqeTa1judNpaQo7tGhoOe7riBgJ+R1AKQ0JwAw06c6YLQ8QNDNcNdJWoHEJbWBeEgB0lXB4yZheVECFjvcIQd7ZxxJROj8AoBBxXXFHn0BFrpit+OfmBLoWHzYdzdULLxH4O9wm0VRgS9oohVY11WPzpIOLPT1zEH4UIwqRSXlSiURVMuXx0viAvh5DtGmpt821iuXA6hSE6/ftsOYQ6hSFckdjGbGNTPheXcgMpTOfQubQkG+dY1qqnLTWIdnnJOHH+7BwvzuKPQhWqLypRIghjgE8QEs6fRwkJrljziz92UYnh/oLcJtGcbUlpEK1TLGKr0KRcAJqem+kega4XhlAELQ5JW1VwEsSfDiL+aJImAMsRhzCplPch1GgJo+XoaVQF5wuU6Awu/jgrhJgWf6YbROsIoihCxtzVWtr9j1tjpZLg25Ii/sJuEm08H0/En9HQrzh/8RN/Ti5dl3BB6FdBgiAIgiAIQqpIr0OYrSnJG6TCufkC4fyxw7SEu0i6fDrnCO4GBl0yznSPQNOOI9e947qBJquCK+8XbJvp8LC4gXruGgdblcfk2JAbRANuLQ/HmkcM3cC4OX8ULrh/HAoF+ruppzFxIbWCMK86SkQbV5hR+0Uh/joKvKpgHfEXuKjEcMGHzlrBxtf3ZReaBBesQYtDdMRfOyWQUlYVbFJM2WoJY1r8kcdzXPwFFRMi/sInbvPtiahWKrFFegVhoaNEoIn4YxaVOFTtq+MQ6jiJpsVfULfS9eXiX
Bd/XIKuZRxJ3h7zeOR+Iv5450yI+IvC+UuaACwn6TmEqRWEBeRLBJqIv+B9CHXEn61+gPxwa3C3Ukf8cdw/rargSMSf2bGm+wGanF8sm0GnXPzpCD3yeAkWf0kXelxEECaUfKG9RMiJ+DMbMjYt/qJwCIP2AwTMiz/OWK7462gnKugt5QFScMUfF1s9AYMenxwr4o8NR/xItS8fG+KP83xMC/gg5PMeMlV+N+l8l0WNVBkLgiAIgiCknPQ6hGVFJTpuoOl+gK64gdQ2cQPtuIHUWNfdwEgaM1tyA0PvQyhuID0Xh91A150/ClfdQFcpqAAhY+K321VSLAjbSxpKs1vHGG7/0pr3/4CbXg2ErDI1mEPoUhFIUsVf51zKxHkMxV8UVcGRVEszqozJcSL+6PNaKPoQ8RfCOWIs9jioADmEsnRdDFCqUCLudMQfJerYgsvSUnBUsQTnHC6JP24FcBLEH+AXgCL+7Ig/amwkFcUi/ni5Zg7l1InQSxZSVJJQfH0IieXnaGEWP/HHXQeYcw5bjZ9dav9iQ/xR+4n4syP+uMezJf7iFuIF7IR5XXf5bFX2uiL+fEvXKfvzkj6ECaUzZLzj8ulwLlOsRZDfx83T469bzBNTHEHouvjjrAtc3Vx4Y8OuAjYt/rjz0Gn/IuKPGux/o6Td5as41oL4E5cvfOLS1kYcwoTSmveQ6/YlHrfiDoCfG2iywCOKnD/qGrircrgk/rRy/BgiiTuuo92Oy0cRhfjTyckLO4cwyeIvCcUccXT5XBJ6cRF2Qelcuq5aQRjSZEJA2s4IgiAIgiCknNQ6hG35TEkxiEs5f0ELPoDwiz6iqPbVcf70ViXxbyNd2AjWAaaO19FRXlTin5vrOX8UUeQVss8bdM4pK/iImxsozp8eSXf+uEjIOKG0FjzUdntQSc35qzSXoCFjnZw/nWpf00KPezzTIU69MLK5edgq+KBwSfyZLPpwqeAjipw/W2Ffk2IvyUJPRJ0+UlSSUFoLGdR2cwht9PQDomnqbDLvj7vebxQFH1FU+3KcumrGmqz4dV382corJM9hoejDdfEXhcvnSiWvCD23YT1XB65TGlMnlLa8V1JFzHXgwm7rAkRVuMHbrzzMG8cQr42Cj+qO559f0N53Lom/KKp9XWnqrCP+xOVz2+UToUcTRfW1a0jIOKF0FLwScZcml6/SfhynL4qefi6FePVcQ2J+BsO8LrlyTuf3AVrijyP2xOWj0RKxIQsOEXrpFHU6qAAhYxWjkLFUGQuCIAiCIKSc1DqErXkPNT30IdRxCMMu7qi0X9jVvS6FfY0XbcRslQ8p+Kh0QI2wbMDQahzDvq6EeAGzrlkUzp+4fOmkUPDgScg4ebTlM6jpoe2M6/l9Npo664R9+cvKuRPOdEX8Udsk54/Glepe02HfOOb3xS2fz5bQS7uo49x3J0R4QVVf3OLCvJmkVhC25oFstx+nJOT3VRprso2LuHzhC71K20zOLW0un40cP8nvqzBWRF3scEKMOUCmoKq+FypG9y69grAAZLo5YkkWeiZdSNMtXHQEDF84EvMz3DTZRnWvuHz8sTr7mRR/IvT0sCFM4ij0RMCFg5dX7O+h7mPiQmoF4Sd5AN0EIVuEsV1DXt6eTmNmG2v0SjjX/LVRhB0yZp8z5S4fwBM6IvSqOF7IYkUEnBAWXgCHsBCjZytVxoIgCIIgCCkntQ5hax7wAjiEcQvxAmbDvOL80dgo8Ehb2NdWD7+wE97F+aNx2ekTRy+dZFT1DmFGxee9klpB+KkvZBy8OtfW2rs28vniKPQobAlb9lwC5hCSx48g7BtFfp+tqt2gP/7s+TpUiRuF0LEh9ETA0e+LQtb/PWCrYXdPuDAvr6Cqfv+6/IdNOSkWhB7Q7YcyinV2XRZ63OOJ0NM8r8FzpC2/zxWhBzALTcS90yIJIs4FEbMzXJ+fa2Ty1d+zDKEjXCW1grC1A0DHjn+7HKatuJ+VQgZinAg941XLnONxw
7lpE3o6Y8N29Vxy79Ik6kT4CCYI0nbGlc8Ah9QKwk8LgOqm3JMi9HQcPIqgDiE3R488Z8qFXqWx/jdf+OHcKPL2XHH0dM5hS+glWdSJiBNcI+mCUKqMBUEQBEEQUk5qHcLWfM8OoesuH4XpkG45Ya++AbhTtKF7PO5Yumq3ZydNnD/zzl9QV8oll08cveBQz0Jl/J/ZOBUKJAUX7rkUlSSU1rYMVO0OgzQpoo4+XvCQLqv9iaVwLjnW+cbMvFh90Hw+l3L5XMnbA+yEdF1aA5g8h+MCzpUfUlfmIdgn6SHj1ArCtrYsUJst/ttGPl4156Cw1UvPhXNqnzehFbpJFnpxy9NLiqgTQSQInWTyKkCVcXw+P6kVhO1tWaCmmyC05Mpxx7ossJIi6qIo0uDu5/L6ubaEXhybP/uOJQKOhOucC+5QyKSvBEEaUyeUttYMVHbnIeM4Cjj2OQKKLp1zco5f1fE0RB2F6+vnhh0y1jleFO6dK6IOMCvsXBJwIswELibfK7ERlwFyCLkrO7lAagVha1sWhWx3h9CsU0XuZ1g4mp4LZz/Tx+cvgxY8944+np1CC+5cgu6XZPfOpUbP5UTT/iUZYi1OOVU2KRDFLEmF89524f2f9BzCmMhyQRAEQRAEISzS6xB+UoN8D5cfN/cukvM65Ojp7OeyyweYXRpNXD4ak66eC+5FF3FyJITKRPEc0+RCmkCKShJKW2sGhUx2p/uww54aY42HYMnB5kScLQFHziUBoq6q8zKuI8mizny/vnBFXFKEWZx+0MKgkE2uaDL9Hk26wJQ+hAmlvS0bqiDUSSQl89uIYgkK04It6DitAoA4ik5H3LokizrTAs4VwZZ2weU6tp5PHIUo5zMVZ9GYKRSq/h5yKXrQE6kVhGhTQGbnb15KhHEFF3k8h6okTbtrvnERXGsUVacuV89GcQ1pE3Aui7M4OQ1xgFqBxCVMvxddEZhhmxFhkvSiktQKwpq2PGq8fM87MjAeyjL8RWBDiNpqEeKygKvmHOTYgNfGdw3dFnDSw08wSRTP2iXRqfP5cUVM2ibpOYRSZSwIgiAIgpByUusQ1rYWUAMzDiFFFDax+SR7g1WXDrmccXP0AG6VsTuOnssVwLrEKeQjuJWjZvp9bMtx5H6+k+4keipAUUnAlUoWLFiAW265BU1NTTjssMNw2223YcyYMRX3f+SRR3DNNddg48aN+PznP4+bbroJJ5xwQlXnTK8gNBgy5uKSdRz2j5zLYhWIZ16dy4Ld9edtGtfn5wo2xFmS27VwP2cuC8c4i8aocggXL16MxsZGLFy4EGPHjsX8+fMxadIkrF+/HgMHDvTt/8ILL+A73/kO5s6di69//et46KGHcPLJJ+Oll17CiBEj2Of1lLK/0F6USrilpQV9+/bFxHN/gdrcLqYuwRlc+aGKY76XK2INMHv/4ijWXHkfC/HDJYeQwqX5uZTjWE65cGxv247/75fTsGXLFvTp0yfSuXTphlO+eTdqa6vTDe3t2/HoY+dVNe+xY8fi8MMPx+233w4AKBQKGDp0KC6++GLMmDHDt//pp5+Obdu24cknnyxuO+KIIzBy5EgsXLiQPVfrDqEtJVzbmketitYhTBO2Su1dFmaA6WbIyRBmLjnnQviE7RC53ltPZ36m56LzfRS2mCz/XnDheyKTV8j00J2EGgN0isru1NXVoa6uzrd/W1sbXnzxRcycOXPHMTIZTJw4EatXrybPsXr1ajQ2NpZsmzRpEh5//PGq5mpdEM6bNw/nnXcepk2bBgBYuHAhlixZgvvuu49UwrfeeiuOO+44/OhHPwIAXH/99Vi+fDluv/32qpRwrrUDtarDzEUwibNVXi0ufHi7iGad2fifw9YzcylfUPBj+oc/7PeZ6e9Z7ucuCufPpblQn1uXHUcjFAJ8X33mjQwdOrRk86xZs3Ddddf5dt+8eTPy+TwGDRpUsn3QoEFYt24deYqmpiZy/6ampqqmalUQRqGEW1tb0
draWvx3uUoXBEEQBEEIk3fffbckZEy5g7axKgijUMJz587F7Nmz/f9j+8dAu5mQsUs5IVxCn3MMXR8dB477TtJxRwK/W5nR+ygSKCQ3MHzi+H1ktAFae/ChtqI4tp6ZjfMqL9g529s/6Rxvsewh37YdHVV+h+U7Oufdp08fVg7hgAEDkM1m0dzcXLK9ubkZDQ0N5JiGhoaq9q+E9ZBx2MycObPEUXz//fdx0EEH4b+f/L7FWQmCIAiCUC1bt25F3759Iz1nLpdDQ0MDfv30pYHGNzQ0IJfLsc81atQorFixAieffDKAzqKSFStWYPr06eSYcePGYcWKFbj00h3zW758OcaNG1fVPK0KwiiUcHniZu/evfHuu+9it912w9atWzF06FCflStES0tLizwHB5Dn4AbyHNxAnoMbdD2Hd955B57nYfDgwZHPob6+Hhs2bEBbW1ug8blcDvX19ez9GxsbMXXqVIwePRpjxozB/PnzsW3btmKtxVlnnYUhQ4Zg7ty5AIBLLrkEEyZMwE9/+lOceOKJePjhh/HnP/8Zd911V1XztCoIbSjhTCaDvfbaCwDgfWZdc61cIVzkObiBPAc3kOfgBvIc3KBv375Wn0N9fX1Vok6H008/HR9++CGuvfZaNDU1YeTIkVi2bFkxXe6dd95BJrMjz2L8+PF46KGHcPXVV+PKK6/E5z//eTz++ONVdV4BHOhDuHjxYkydOhX/9V//VVTCv/rVr7Bu3ToMGjTIp4RfeOEFTJgwATfeeGNRCc+ZM6fqtjPAjt5CNvoaCTuQ5+AG8hzcQJ6DG8hzcAN5DtFhPYfQlhIWBEEQBEEQOrEuCAFg+vTpFUPEq1at8m079dRTceqpp2qft66uDrNmzXKy/DtNyHNwA3kObiDPwQ3kObiBPIfosB4yFgRBEARBEOxisvuTIAiCIAiCEENEEAqCIAiCIKQcEYSCIAiCIAgpJ/GCcMGCBRg+fDjq6+sxduxYrFmzZqf7P/LIIzjggANQX1+PQw45BEuXLo1opsmmmudw991346ijjsLuu++O3XffHRMnTuzxuQk8qv08dPHwww/D87xiv1BBj2qfw0cffYSLLroIe+65J+rq6vCFL3xBvpsMUO1zmD9/Pr74xS+iV69eGDp0KC677DJ8+umnEc02efzud7/D5MmTMXjwYHieh8cff7zHMatWrcKXv/xl1NXVYf/998cDDzwQ+jxTg0owDz/8sMrlcuq+++5Tf/3rX9V5552n+vXrp5qbm8n9n3/+eZXNZtXNN9+sXnvtNXX11Ver2tpa9corr0Q882RR7XM444wz1IIFC9TLL7+sXn/9dXX22Wervn37qvfeey/imSeLap9DFxs2bFBDhgxRRx11lPrGN74RzWQTTLXPobW1VY0ePVqdcMIJ6rnnnlMbNmxQq1atUmvXro145smi2ufwi1/8QtXV1alf/OIXasOGDeqpp55Se+65p7rssssinnlyWLp0qbrqqqvUb37zGwVAPfbYYzvd/6233lK77LKLamxsVK+99pq67bbbVDabVcuWLYtmwgkn0YJwzJgx6qKLLir+O5/Pq8GDB6u5c+eS+5922mnqxBNPLNk2duxY9X/+z/8JdZ5Jp9rnUE5HR4fabbfd1KJFi8KaYioI8hw6OjrU+PHj1T333KOmTp0qgtAA1T6HO++8U+27776qra0tqimmgmqfw0UXXaS++tWvlmxrbGxURx55ZKjzTAscQXj55Zergw8+uGTb6aefriZNmhTizNJDYkPGbW1tePHFFzFx4sTitkwmg4kTJ2L16tXkmNWrV5fsDwCTJk2quL/QM0GeQznbt29He3s7+vfvH9Y0E0/Q5/DjH/8YAwcOxHe/+90oppl4gjyHJ554AuPGjcNFF12EQYMGYcSIEZgzZw7y+XxU004cQZ7D+PHj8eKLLxbDym+99RaWLl2KE044IZI5C/IbHTZONKYOg82bNyOfzxdXPOli0KBBWLduHTmmq
amJ3L+pqSm0eSadIM+hnCuuuAKDBw/2fREIfII8h+eeew733nsv1q5dG8EM00GQ5/DWW29h5cqVOPPMM7F06VK88cYb+N73vof29nbMmjUrimknjiDP4YwzzsDmzZvxla98BUopdHR04IILLsCVV14ZxZQFVP6NbmlpwSeffIJevXpZmlkySKxDKCSDG2+8EQ8//DAee+yxyBYWF4CtW7diypQpuPvuuzFgwADb00k1hUIBAwcOxF133YVRo0bh9NNPx1VXXYWFCxfanlqqWLVqFebMmYM77rgDL730En7zm99gyZIluP76621PTRCMkFiHcMCAAchms2hubi7Z3tzcjIaGBnJMQ0NDVfsLPRPkOXTxk5/8BDfeeCOeeeYZHHrooWFOM/FU+xzefPNNbNy4EZMnTy5uKxQKAICamhqsX78e++23X7iTTiBBPg977rknamtrkc1mi9sOPPBANDU1oa2tDblcLtQ5J5Egz+Gaa67BlClTcO655wIADjnkEGzbtg3nn38+rrrqKmQy4q+ETaXf6D59+og7aIDEvoNzuRxGjRqFFStWFLcVCgWsWLEC48aNI8eMGzeuZH8AWL58ecX9hZ4J8hwA4Oabb8b111+PZcuWYfTo0VFMNdFU+xwOOOAAvPLKK1i7dm3xddJJJ+GYY47B2rVrMXTo0CinnxiCfB6OPPJIvPHGG0VBDgB/+9vfsOeee4oYDEiQ57B9+3af6OsS6UpWgI0E+Y0OGdtVLWHy8MMPq7q6OvXAAw+o1157TZ1//vmqX79+qqmpSSml1JQpU9SMGTOK+z///POqpqZG/eQnP1Gvv/66mjVrlrSdMUC1z+HGG29UuVxOPfroo+qDDz4ovrZu3WrrEhJBtc+hHKkyNkO1z+Gdd95Ru+22m5o+fbpav369evLJJ9XAgQPVf/zHf9i6hERQ7XOYNWuW2m233dQvf/lL9dZbb6mnn35a7bfffuq0006zdQmxZ+vWrerll19WL7/8sgKg5s2bp15++WX19ttvK6WUmjFjhpoyZUpx/662Mz/60Y/U66+/rhYsWCBtZwySaEGolFK33Xab2nvvvVUul1NjxoxRf/jDH4r/b8KECWrq1Kkl+//qV79SX/jCF1Qul1MHH3ywWrJkScQzTibVPIdhw4YpAL7XrFmzop94wqj289AdEYTmqPY5vPDCC2rs2LGqrq5O7bvvvuqGG25QHR0dEc86eVTzHNrb29V1112n9ttvP1VfX6+GDh2qvve976l//etf0U88ITz77LPkd33XfZ86daqaMGGCb8zIkSNVLpdT++67r7r//vsjn3dS8ZQSr1sQBEEQBCHNJDaHUBAEQRAEQeAhglAQBEEQBCHliCAUBEEQBEFIOSIIBUEQBEEQUo4IQkEQBEEQhJQjglAQBEEQBCHliCAUBEEQBEFIOSIIBUEQBEEQUo4IQkEQBEEQhJQjglAQhERw9NFH49JLL7U9DUEQhFgiglAQBEEQBCHlyFrGgiDEnrPPPhuLFi0q2bZhwwYMHz7czoQEQRBihghCQRBiz5YtW3D88cdjxIgR+PGPfwwA2GOPPZDNZi3PTBAEIR7U2J6AIAiCLn379kUul8Muu+yChoYG29MRBEGIHZJDKAiCIAiCkHJEEAqCIAiCIKQcEYSCICSCXC6HfD5vexqCIAixRAShIAiJYPjw4fjjH/+IjRs3YvPmzSgUCranJAiCEBtEEAqCkAh++MMfIpvN4qCDDsIee+yBd955x/aUBEEQYoO0nREEQRAEQUg54hAKgiAIgiCkHBGEgiAIgiAIKUcEoSAIgiAIQsoRQSgIgiAIgpByRBAKgiAIgiCkHBGEgiAIgiAIKUcEoSAIgiAIQsoRQSgIgiAIgpByRBAKgiAIgiCkHBGEgiAIgiAIKUcEoSAIgiAIQsoRQSgIgiAIgpBy/n+f3MdvLcY6ygAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "N_t, N_x = 100, 256\n", "\n", "t = np.linspace(0.0, 1.0, N_t)\n", "x = np.linspace(0.0, 1.0, N_x)\n", "T, X = np.meshgrid(t, x, indexing='ij')\n", "coords = np.stack([T.flatten(), X.flatten()], axis=1)\n", "\n", "output = model(jnp.array(coords))\n", "resplot = np.array(output).reshape(N_t, N_x)\n", "\n", "plt.figure(figsize=(7, 4))\n", "plt.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n", "plt.colorbar()\n", "\n", "plt.title('Solution of Diffusion Equation')\n", "plt.xlabel('t')\n", "\n", "plt.ylabel('x')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "eba27f8f-c733-4114-824d-ace0c208f441", "metadata": {}, "source": [ "We can also visualize the difference between the analytical and the approximated solution." ] }, { "cell_type": "code", "execution_count": 10, "id": "215ff5fe-6bf4-4636-8e3a-f27712af0274", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAqEAAAGGCAYAAABL41hNAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA+ABJREFUeJzsvXt8FeW5Pb723kl2wiVB5BLACKhUtNwqCKJ4p6XeeujRKmoFqUVrRcFYFawIahWVqnjDFK3aWjlaqnL8iqVi1KMt+XkBbWsVai2IRYMihcgtl73f3x802+yZFbL2vJOd4H5XP/mcw/jOzDuXPbPmeZ61nogxxsDBwcHBwcHBwcEhi4i29QQcHBwcHBwcHBxyD46EOjg4ODg4ODg4ZB2OhDo4ODg4ODg4OGQdjoQ6ODg4ODg4ODhkHY6EOjg4ODg4ODg4ZB2OhDo4ODg4ODg4OGQdjoQ6ODg4ODg4ODhkHY6EOjg4ODg4ODg4ZB2OhDo4ODg4ODg4OGQdjoSGjDlz5iASiaQta2howFVXXYWysjJEo1GMHz8eALBt2zb88Ic/RGlpKSKRCKZPn579CbcTrFu3DpFIBD//+c/beioO7QzHHXccjjvuuLaeRqhovN8feeSRtp5KCu39PLfGOWvt6/Dyyy8jEong5ZdfbpXtM7B3UBjo168fzj///NC365DbcCR0D3jkkUcQiURSf4WFhejduzfGjRuHu+++G1988YW0nYceegjz5s3DGWecgV/96le4/PLLAQA333wzHnnkEVx88cV49NFHcd5557Xm4eQc3n33XcyZMwfr1q1r66k47KVYsGBBuyKKDsGwaNEizJ8/v62n0e6xYsUKzJkzB1u2bGnrqTjkCPLaegJ7A2644Qb0798f9fX1qK6uxssvv4zp06fjjjvuwDPPPIMhQ4akxl577bWYMWNG2vovvvgi+vTpgzvvvNO3/IgjjsDs2bOzchy5hnfffRfXX389jjvuOPTr16+tp+M
QEM8//3yb7XvBggXo1q2biwDt5Vi0aBHeeecdX7apb9++2LlzJ/Lz89tmYu0MK1aswPXXX4/zzz8fXbp0Sftva9asQTTq4lYO4cKRUAEnnXQSRowYkfr3zJkz8eKLL+LUU0/Fd77zHbz33nsoKioCAOTl5SEvL/20fvrpp74fdOPyQw89NLR5JpNJ1NXVobCwMLRtOjioMMZg165dqd9CWCgoKAh1ew4OjWjMcDm0jHg83tZTcPgKwn3WBMQJJ5yAWbNm4cMPP8RvfvOb1PKm9TiN9UYvvfQS/va3v6XS+o11QmvXrsXSpUtTyxvTxrW1tZg9ezYOOuggxONxlJWV4aqrrkJtbW3aHCKRCKZOnYrHHnsMX//61xGPx7Fs2TIAwIYNG/CDH/wAPXv2RDwex9e//nU89NBDaes3zuO3v/0tbrrpJuy3334oLCzEiSeeiH/84x++Y37ttddw8sknY5999kHHjh0xZMgQ3HXXXWljVq9ejTPOOANdu3ZFYWEhRowYgWeeeSajc3vnnXeib9++KCoqwrHHHot33nnHN6al/TzyyCP43ve+BwA4/vjj0859eXk59t13XxhjUuMvvfRSRCIR3H333allGzduRCQSwf33359apl4bAPjNb36D4cOHo6ioCF27dsWECRPw0UcfpY057rjjMGjQILz77rs4/vjj0aFDB/Tp0we33XabdK4efvhhnHDCCejRowfi8TgOPfTQtPk2ol+/fjj11FPx/PPPY9iwYSgsLMShhx6Kp556Km1cYwnKK6+8gosuugj77rsviouLMXHiRPz73/+m2/zDH/6AESNGoKioCL/4xS8AAP/85z/xve99D127dkWHDh1wxBFHYOnSpal1Gz/cJk6cmLbNP/7xj4jFYrj66qvTzlHTWsWm9+3111+PPn36oHPnzjjjjDOwdetW1NbWYvr06ejRowc6deqEyZMn+66Pct769euHv/3tb/i///u/1P3TdB5btmzB9OnTUVZWhng8joMOOgi33norkslk2na2bNmC888/HyUlJejSpQsmTZokpzs3b96Mn/zkJxg8eDA6deqE4uJinHTSSfjzn/+cNi7T3/LChQtx4IEHoqioCCNHjsSrr74qzQfI/J774x//iJEjR6KwsBAHHHAAfv3rXwc6RjaPSCSCt956y/ffbr75ZsRiMWzYsAHHHXccli5dig8//DB1HRszI83VhK5evRpnnnkmunfvjqKiIhx88MH46U9/mvrvH374IX784x/j4IMPRlFREfbdd19873vfC1z688UXX2D69Ono168f4vE4evTogW9+85tYtWpV2rjFixennindunXD97//fWzYsGGP295T3WskEsGcOXMA7H53XXnllQCA/v37+95LrCa0pd85kPm96ZBjMA7N4uGHHzYAzBtvvEH/+0cffWQAmDPOOCO1bPbs2abxtG7bts08+uijZuDAgWa//fYzjz76qHn00UdNdXW1efTRR023bt3MsGHDUsu3bdtmEomE+da3vmU6dOhgpk+fbn7xi1+YqVOnmry8PPNf//VfafsHYA455BDTvXt3c/3115v77rvPvPXWW6a6utrst99+pqyszNxwww3m/vvvN9/5zncMAHPnnXem1n/ppZcMAPONb3zDDB8+3Nx5551mzpw5pkOHDmbkyJFp+3r++edNQUGB6du3r5k9e7a5//77zWWXXWbGjh2bGvPOO++YkpISc+ihh5pbb73V3HvvveaYY44xkUjEPPXUU3s812vXrjUAzODBg02/fv3Mrbfeaq6//nrTtWtX0717d1NdXZ3Rfj744ANz2WWXGQDmmmuuSTv3Tz31lAFg/vrXv6a2OXToUBONRtOu5eLFiw0A88477xhjTEbX5mc/+5mJRCLmrLPOMgsWLDDXX3+96datm+nXr5/597//nRp37LHHmt69e5uysjIzbdo0s2DBAnPCCScYAOa5557b4zkzxpjDDz/cnH/++ebOO+8099x
zj/nWt75lAJh77703bVzfvn3N1772NdOlSxczY8YMc8cdd5jBgwebaDRqnn/++dS4xnt+8ODB5uijjzZ33323ueSSS0w0GjXHHHOMSSaTads86KCDzD777GNmzJhhKioqzEsvvWSqq6tNz549TefOnc1Pf/pTc8cdd6TOb9P7YN68eQaA+d///V9jzO7fy4EHHmgOPfRQs2vXrrRzdOyxx6b+3XjfDhs2zIwePdrcfffd5rLLLjORSMRMmDDBnHPOOeakk04y9913nznvvPMMAHP99ddnfN6efvpps99++5mBAwem7p/Gc7V9+3YzZMgQs++++5prrrnGVFRUmIkTJ5pIJGKmTZuW2kYymTTHHHOMiUaj5sc//rG55557zAknnGCGDBliAJiHH354j9f3jTfeMAceeKCZMWOG+cUvfmFuuOEG06dPH1NSUmI2bNjgOyfKb/nBBx80AMyRRx5p7r77bjN9+nTTpUsXc8ABB6Sd5+aQyT138MEHm549e5prrrnG3Hvvveawww4zkUgk9ZvK5BgbnxGN56ympsYUFRWZK664wjfHQw891JxwwgnGmN3PrmHDhplu3bqlruPTTz9Nt2mMMX/+859NcXGx2Xfffc3MmTPNL37xC3PVVVeZwYMHp8YsXrzYDB061Fx33XVm4cKF5pprrjH77LOP6du3r9m+fbvvurz00kt7PKfnnHOOKSgoMOXl5ebBBx80t956qznttNPMb37zm9SYxt/m4Ycfbu68804zY8YMU1RU5HumNH0HNXeMjQBgZs+enTrus88+O/WeaPpeMmb39Zw0aVJqXfV3nsm96ZB7cCR0D2iJhBpjTElJifnGN76R+rf3AWDM7pfo17/+dd+6ffv2NaecckraskcffdREo1Hz6quvpi2vqKgwAMyf/vSn1DIAJhqNmr/97W9pYy+44ALTq1cvs2nTprTlEyZMMCUlJWbHjh3GmC8fDocccoipra1NjbvrrrvSSFpDQ4Pp37+/6du3b9rDzhiTRkpOPPFEM3jw4DQCkUwmzZFHHmkGDBjgO/6maHxQFhUVmX/961+p5a+99poBYC6//PKM99NIIr0vgE8//dQAMAsWLDDGGLNlyxYTjUbN9773PdOzZ8/UuMsuu8x07do1dYzqtVm3bp2JxWLmpptuShv317/+1eTl5aUtP/bYYw0A8+tf/zq1rLa21pSWlprTTz99j+fMGJO6lk0xbtw4c8ABB6Qt69u3rwFgnnzyydSyrVu3ml69eqXdv433/PDhw01dXV1q+W233ZZGGJtuc9myZWn7mj59ugGQdp6++OIL079/f9OvXz+TSCSMMbtJ/ZgxY0zPnj3Npk2bzCWXXGLy8vJ8v7fmSOigQYPS5nj22WebSCRiTjrppLT1R48ebfr27RvovH3961+nxOzGG280HTt2NH//+9/Tls+YMcPEYjGzfv16Y4wxS5YsMQDMbbfdlhrT0NBgjj76aImE7tq1K3W+GrF27VoTj8fNDTfckFqm/pbr6upMjx49zLBhw9LGLVy40ACQSGim99wrr7ySWvbpp5+aeDyeRhzVY2Rk6uyzzza9e/dOW3/VqlW+caeccorvHmhum8ccc4zp3Lmz+fDDD9PGNn3WsXNQVVXl+y2rJLSkpMRccsklzf73xus2aNAgs3PnztTyZ5991gAw1113XWpZUBJqzJcfhmvXrvWN9ZJQ9Xeu3psOuQmXjrdEp06dZJW8gsWLF+OQQw7BwIEDsWnTptTfCSecAAB46aWX0sYfe+yxaXWlxhg8+eSTOO2002CMSdvGuHHjsHXrVl+KZ/LkyWl1d0cffTSA3akWAHjrrbewdu1aTJ8+3Vfb2lh6sHnzZrz44os488wz8cUXX6T2+fnnn2PcuHF4//33W0wbAcD48ePRp0+f1L9HjhyJUaNG4bnnngttP927d8fAgQPxyiuvAAD+9Kc/IRaL4corr8TGjRvx/vvvAwBeffVVjBkzJnWM6rV56qm
nkEwmceaZZ6aNKy0txYABA3zXsFOnTvj+97+f+ndBQQFGjhyZOv97QtP6y61bt2LTpk049thj8c9//hNbt25NG9u7d29897vfTf27Mc3+1ltvobq6Om3shRdemCbWuPjii5GXl5e6Do3o378/xo0bl7bsueeew8iRIzFmzJi0Y7zwwguxbt06vPvuuwCAaDSKRx55BNu2bcNJJ52EBQsWYObMmWn113vCxIkT0+Y4atQoGGPwgx/8IG3cqFGj8NFHH6GhoSG1LJPzxrB48WIcffTR2GeffdKu8dixY5FIJFL31nPPPYe8vDxcfPHFqXVjsRguvfRS6Rjj8XhKDJJIJPD555+jU6dOOPjgg32/Y6Dl3/Kbb76JTz/9FD/60Y/SxjWWCyjI5NwdeuihqTkAu397Bx98cNq9nekxNsXEiRPx8ccfp/2mHnvsMRQVFeH000+XjqcpPvvsM7zyyiv4wQ9+gP333z/tvzW1PWp6Durr6/H555/joIMOQpcuXVqcM0OXLl3w2muv4eOPP6b/vfG6/fjHP06rYT3llFMwcOBAXwo8G1B/541o6d50yE04YZIltm3bhh49eoS2vffffx/vvfceunfvTv/7p59+mvbv/v37p/37s88+w5YtW7Bw4UIsXLhQ2ob3YbvPPvsAQKoG8IMPPgAADBo0qNl5/+Mf/4AxBrNmzcKsWbOa3W9TgskwYMAA37Kvfe1r+O1vfxvqfo4++ugUoXr11VcxYsQIjBgxAl27dsWrr76Knj174s9//jPOOeec1DrqtXn//fdhjKHHAsCnxN1vv/18vn777LMP/vKXv+zxGIDdBHr27NmoqqrCjh070v7b1q1b04jFQQcd5NvP1772NQC768ZKS0tTy71z79SpE3r16uWrefPef8DuerlRo0b5lh9yyCGp/954Lx144IGpWrRBgwY1e00ZvPdt47GWlZX5lieTSWzduhX77rsvgMzOG8P777+Pv/zlLy3eCx9++CF69eqFTp06pf33gw8+uIWj241kMom77roLCxYswNq1a5FIJFL/rfFYmqKl3/KHH34IwH998/PzccABB0hzyuTceefTOKem9cWZHmNTfPOb30SvXr3w2GOP4cQTT0QymcT//M//4L/+67/QuXNn6XiaopEQ7elZBwA7d+7E3Llz8fDDD2PDhg1p9eXKR4wXt912GyZNmoSysjIMHz4cJ598MiZOnJi6Jo3Xjd03AwcOxB//+MeM92mLTH7nQMv3pkNuwpFQC/zrX//C1q1bcdBBB4W2zWQyicGDB+OOO+6g/937gvUqkRtFEd///vcxadIkuo2mllLA7sgMQ9MHa0to3O9PfvITX2SsEWGcp7D2M2bMGDzwwAP45z//iVdffRVHH300IpEIxowZg1dffRW9e/dGMplMi+Ko1yaZTCISieD3v/89PbdeQhL0/H/wwQc48cQTMXDgQNxxxx0oKytDQUEBnnvuOdx5550+gUxrIAwlfKMF08cff4zPP/88jQzvCc2dt5bOZxjnLZlM4pvf/Cauuuoq+t8byb0tbr75ZsyaNQs/+MEPcOONN6Jr166IRqOYPn06nWcYv+U9IdNzp8wn02P0bv+cc87BAw88gAULFuBPf/oTPv7447TMQmvg0ksvxcMPP4zp06dj9OjRKCkpQSQSwYQJEwL97s4880wcffTRePrpp/H8889j3rx5uPXWW/HUU0/hpJNOspprc8b1Tcl+NtDa96bD3glHQi3w6KOPAkCzZCgIDjzwQPz5z3/GiSeeGKjrRffu3dG5c2ckEgmMHTs2tDkBwDvvvNPsNhu/2PPz863225gKb4q///3vKTVrJvvZ0/lrJJfLly/HG2+8kfJ2PeaYY3D//fejd+/e6NixI4YPH55aR702Bx54IIwx6N+/f2hkhOH//b//h9raWjzzzDNpUQZvur8RjVHkpnP/+9//DgA+H9X3338fxx9/fOrf27ZtwyeffIKTTz65xXn17dsXa9a
s8S1fvXp16r83oqKiAsuXL8dNN92EuXPn4qKLLsL//u//trgPG2Ry3pq7zgceeCC2bdvW4j3Yt29fVFZWYtu2bWkfH+z8MPzud7/D8ccfj1/+8pdpy7ds2YJu3bpJ2/DOB9h9fRvLSIDdKeW1a9di6NChe1w/03tOge0xTpw4Ebfffjv+3//7f/j973+P7t27+57J6rO08fnCHDm8c540aRJuv/321LJdu3ZZmbz36tULP/7xj/HjH/8Yn376KQ477DDcdNNNOOmkk1LXbc2aNWnXrXFZ09+UF40RR+/cGqOrTZHJOyeT37mDQ3NwNaEB8eKLL+LGG29E//79ce6554a23TPPPBMbNmzAAw884PtvO3fuxPbt2/e4fiwWw+mnn44nn3ySPkg/++yzjOd02GGHoX///pg/f77vQdb4FdujRw8cd9xx+MUvfoFPPvkk8H6XLFmSVtP5+uuv47XXXktFAzLZT8eOHQH4H77A7jRyYwOB+vp6HHXUUQB2k9MPPvgAv/vd73DEEUekeb6q1+a///u/EYvFcP311/u+8o0x+Pzzz6Vz0RIaIwveVODDDz9Mx3/88cd4+umnU/+uqanBr3/9awwbNswXfVy4cCHq6+tT/77//vvR0NAgRWVOPvlkvP7666iqqkot2759OxYuXIh+/fqlapjXrl2LK6+8EqeffjquueYa/PznP8czzzzjs/AJG5mct44dO9L758wzz0RVVRX+8Ic/+P7bli1bUvWnJ598MhoaGtIsjBKJBO655x55rt57aPHixVJ9NcOIESPQvXt3VFRUoK6uLrX8kUcekQhUpvecAttjHDJkCIYMGYIHH3wQTz75JCZMmODzau7YsaOUJu/evTuOOeYYPPTQQ1i/fn3af2s6Rzbne+65J1B0MZFI+ObWo0cP9O7dO2UtNmLECPTo0QMVFRVpdmO///3v8d577+GUU05pdvvFxcXo1q1bqk65EQsWLPCN3dMz0wv1d+7gsCe4SKiA3//+91i9ejUaGhqwceNGvPjii1i+fDn69u2LZ555JlSz4/POOw+//e1v8aMf/QgvvfQSjjrqKCQSCaxevRq//e1vU56Me8Itt9yCl156CaNGjcKUKVNw6KGHYvPmzVi1ahVeeOEFbN68OaM5RaNR3H///TjttNMwbNgwTJ48Gb169cLq1avxt7/9LfUivu+++zBmzBgMHjwYU6ZMwQEHHICNGzeiqqoK//rXv1r0/QN2p9LHjBmDiy++GLW1tZg/fz723XfftLSnup9hw4YhFovh1ltvxdatWxGPx1P+hsBuwvn4449j8ODBqWjBYYcdho4dO+Lvf/97Wj0ooF+bAw88ED/72c8wc+ZMrFu3DuPHj0fnzp2xdu1aPP3007jwwgvxk5/8JKNrwPCtb30LBQUFOO2003DRRRdh27ZteOCBB9CjRw9K0L/2ta/hggsuwBtvvIGePXvioYcewsaNGymBqKurw4knnogzzzwTa9aswYIFCzBmzBh85zvfaXFeM2bMwP/8z//gpJNOwmWXXYauXbviV7/6FdauXYsnn3wS0Wg0JSAqKipKEbSLLroITz75JKZNm4axY8eid+/e1ueIIZPzNnz4cNx///342c9+hoMOOgg9evTACSecgCuvvBLPPPMMTj31VJx//vkYPnw4tm/fjr/+9a/43e9+h3Xr1qFbt2447bTTcNRRR2HGjBlYt25dyptVrRs89dRTccMNN2Dy5Mk48sgj8de//hWPPfaYXL/pRX5+Pn72s5/hoosuwgknnICzzjoLa9euxcMPPyxtM9N7TkEYxzhx4sTUb4ql4ocPH44nnngC5eXlOPzww9GpUyecdtppdFt33303xowZg8MOOwwXXngh+vfvj3Xr1mHp0qV4++23U3N+9NFHUVJSgkMPPRRVVVV44YUXWqxhZfjiiy+w33774YwzzsDQoUPRqVMnvPDCC3jjjTdSkdb8/HzceuutmDx5Mo499licffbZ2Lh
xI+666y7069cv1Qq6Ofzwhz/ELbfcgh/+8IcYMWIEXnnllVQWxHueAOCnP/0pJkyYgPz8fJx22mkpctoUyu/cwaFFZEuGvzei0a6m8a+goMCUlpaab37zm+auu+4yNTU1vnVsLZqM2W3Hceutt5qvf/3rJh6Pm3322ccMHz7cXH/99Wbr1q2pcQCatfXYuHGjueSSS0xZWZnJz883paWl5sQTTzQLFy5MjWm0zli8eHHaus1Zevzxj3803/zmN03nzp1Nx44dzZAhQ8w999yTNuaDDz4wEydONKWlpSY/P9/06dPHnHrqqeZ3v/sdnad3n/PmzTO33367KSsrM/F43Bx99NHmz3/+s2+8up8HHnjAHHDAASYWi/msUu677z4DwFx88cVp64wdO9YAMJWVlb79qtfGGGOefPJJM2bMGNOxY0fTsWNHM3DgQHPJJZeYNWvWpMY0d29MmjSJWsp48cwzz5ghQ4aYwsLClL/qQw895LNZabzX/vCHP5ghQ4aYeDxuBg4c6Lv2jff8//3f/5kLL7zQ7LPPPqZTp07m3HPPNZ9//nna2ObuX2N2X58zzjjDdOnSxRQWFpqRI0eaZ599NvXfG+1ZmlpGGWPM+vXrTXFxsTn55JPTzhGzaGpu7l6Lp8bf5GeffZbxeauurjannHKK6dy5s8/C6IsvvjAzZ840Bx10kCkoKDDdunUzRx55pPn5z3+eZh31+eefm/POO88UFxebkpISc95555m33npLtmi64oorTK9evUxRUZE56qijTFVVlXxOmvstL1iwwPTv39/E43EzYsQI88orr/i22Rwyvee88O5HPcY9WQ198sknJhaLma997Wt0ztu2bTPnnHOO6dKliwGQ+m01t8133nnHfPe7303dvwcffLCZNWtW6r//+9//NpMnTzbdunUznTp1MuPGjTOrV6/22RgpFk21tbXmyiuvNEOHDk09W4cOHZqykGuKJ554wnzjG98w8XjcdO3a1Zx77rlplnbG8HfQjh07zAUXXGBKSkpM586dzZlnnpmyqmtq0WTMbvuxPn36mGg0mnZNvcdmTMu/86bnQL03HXILEWNcVbCDw1cd/fr1w6BBg/Dss8/ucdwjjzyCyZMn44033pCtkhwc2hqbNm1Cr169cN1112XksODg4NC2cPFyBwcHB4e9Go888ggSiQTOO++8tp6Kg4NDBnA1oQ4ODg4OeyVefPFFvPvuu7jpppswfvx4n8uDg4ND+4YjoQ4ODg4OeyVuuOEGrFixAkcddZTsOODg4NB+0KY1oa+88grmzZuHlStX4pNPPsHTTz+N8ePH73Gdl19+GeXl5fjb3/6GsrIyXHvttTj//POzMl8HBwcHBwcHB4dw0KY1odu3b8fQoUNx3333SePXrl2LU045BccffzzefvttTJ8+HT/84Q+pV5+Dg4ODg4ODg0P7RbtRx0cikRYjoVdffTWWLl2aZsI+YcIEbNmyBcuWLcvCLB0cHBwcHBwcHMLAXlUTWlVV5WuTN27cOEyfPl3eRjKZxMcff4zOnTsHaovp4ODg4ODgkF0YY/DFF1+gd+/ebWKEv2vXrrQuY5mgoKAg1KY2XyXsVSS0uroaPXv2TFvWs2dP1NTUYOfOnSgqKvKtU1tbm9bmbMOGDa6dmIODg4ODw16Ijz76CPvtt19W97lr1y6UFpVgK4KR0NLSUqxdu9YRUYK9ioQGwdy5c3H99df7lg+9ahFi8Q6pf0cS/qqEWCIp7SO+098vOJL0b89E/ZHX/Fr/ulFxv3TdpH/dKJlLrEEbl1cvzIUMYdtK5vmPP9oQvBqkIa59DSfJeWfLEnn+7TXkxfzj8sm4fP84hp0d88n2/HNpKPBvL79T+onOy/ef+KKODb5l8QL/fVJY5F8WL/SvW1RIxpFDLSnwL+uY57+2HcgTp6O4LB7zby8a8S/rSM5LftQ/rnOe/9jyyLhO/kvmy6JEI/6TEo34V4yRcRFSms+yNFGQdSNkXYS
b4WH7SBryzCP7NfCfTzaObY+ta8jDhq3rBbs+/BiCR9jYNWPbY8eQF/H/gNh5zyP3VF6U/Pjqa/3Ldtb451LrX4adX/iXfe5vM2s++9y3LPGv9HH1a/3r/Xud/7g+IctW/2Vn+rTQgJ9gBTp37uyfXyujrq4OW1GHn+NIFGVIm3aiAT+pXoG6ujpHQgn2KhJaWlqKjRs3pi3buHEjiouLaRQUAGbOnIny8vLUv2tqalBWVob8aCHyYl+uEzOMSfkXxXf6X9R57MFIiAVDfowQSfJijREyaArJQ4+Qv4Ja/5zp8RJykZ9Mnx8j0myfDIyYgrzgGRhpZDcvI4gMDYRwsqnUFrCHvn8usZh/e7VF/hkWknVr4/5xCbJuNJ5+zaKEbJkoeSHnk+tPCGdBx3r/uDghauQ+ibGTRy5QLJ/cn2TdArIuq54pzCNEnJyXeIyMI+syshon42Ke32g85n+5qGSQkQiVDMUi/hOlEj8GRnzoOHF7KrmkxBQtk0tgd5o2ffvkA5uQ0IakP6rFxrH5quAfDuT5Qcglmwsjq5SENpAfVREhxNvZ84IsS/ifDaZ2h29Zcp/0c1r/+S7fmGixf/t1HeK+Zd06p1/HHSYGbOPnL1voGM1HEfnN7QlRE6GBGofd2KtI6OjRo/Hcc8+lLVu+fDlGjx7d7DrxeBzxuP8Gj+0DNOGgaGjQTsWuev8PPkqiiixKyRAr1CKSdF0xmllfS6I0ZFyURIMjHdOXsUgri5YmY9qDgpFrRiQZCaXRVhr11F6sbL+MSDKw/daTaGY9CSOyZQVF5Dx7yFWnzuQlSkhUUQdCOBm5pBFE3yIapexA3nksElpMxrF9sIhkB0IGWXS0E3uJEniJZHPrxigxi3j+RT7OaCSLPD8I2VDJpR7hC5uEhluTRyO6hpAm8vFsPNcxQl5r7Nwx4scIDtPuMqKrEkkVVuc44f/Ngzy7wa53TKQF9Gs0sud/N7PLgjgJMngCOXmm7TUc0RhAHvV7XscgEAm97777MG/ePFRXV2Po0KG45557MHLkyGbHL168GLNmzcK6deswYMAA3HrrrTj55JMBAPX19bj22mvx3HPP4Z///CdKSkowduxY3HLLLejdu3dqG/369cOHH36Ytt25c+dixowZmR+AiDYlodu2bcM//vGP1L/Xrl2Lt99+G127dsX++++PmTNnYsOGDfj1r38NAPjRj36Ee++9F1dddRV+8IMf4MUXX8Rvf/tbLF26NON9NySiQMOXv4aGWjFySaKKKmlixJQRkA5f+MkFSxUXiERXJ7Vsfum3SLTW/2uqI5E8Fn1VwYgpZSoEjBCzoABDgjz0+fb849j1YWAfDuweSCbJQ9lDwhoa/PssKvKf94Z6cn8WkA+HBInQkMgl+VbBLnIrFpBTwsax75UGcvwJ8hJit0o9WZe9ONj2ahP+SbMoqnd7NFpGCFMyQk4AWTVh/JEnRkoYWVWJJINMYMUXsZIqB7hXIDsHLLLoJYRq3JIRyYiw/d3LyAcwOVa2jJduaPuV95HHUvT+qCSihAIwssrSFeTjPuJ5lkXJ12l+of/9lhf377NjJ8/2k1Hg3/5pZBORaATRDCOxkQDk+YknnkB5eTkqKiowatQozJ8/H+PGjcOaNWvQo0cP3/gVK1bg7LPPxty5c3Hqqadi0aJFGD9+PFatWoVBgwZhx44dWLVqFWbNmoWhQ4fi3//+N6ZNm4bvfOc7ePPNN9O2dcMNN2DKlCmpf7d2+UObktA333wTxx9/fOrfjWnzSZMm4ZFHHsEnn3yC9evXp/57//79sXTpUlx++eW46667sN9+++HBBx/EuHHjMt53p071yCv68sfQQFIV9OVdRCJttezTjrwwWTEdwRf5/tQeI2ZqdJARHwYW9fOSsGTU/1BhKXq2jEVHWfTVZhyLPrLjZ/NTySU772zdOnI+aVlB3H9shYRM+tYjkUG2rIA84Nn2WSSUEcRCWsPpPwb23VBIfgKsTpSlxdV
lDDTNTsglQ1605TpOlVgw0kj3GfE/AyhRIUS3raKZDDH2vGClQAT5Uf85SCRJyQiJQHrByTWJLLOaJMrBtVpPZW67t6fVALPzSZEkzw+WEVKXMZBIaMRbG08eIOw5k0c+ivM8GgLyaMs6YjH+TNzjOgGqOe644w5MmTIFkydPBgBUVFRg6dKleOihh2hU8q677sK3v/1tXHnllQCAG2+8EcuXL8e9996LiooKlJSUYPny5Wnr3HvvvRg5ciTWr1+P/fffP7W8c+fOKC0tzXzSAdGmJPS4446jqY5GPPLII3Sdt956y3rfyWQkLdrEIk/NrecFS4FSsFpCKi7StscisHmkZpWBkkQyzlvvaVUTykJoIug+yNMgv85/PlUiyZbZgInd1FxOHSmh6OCp2WT3onof19X5tx8nLwJ2CPUkuLWLRAcKydO3jtxkLDrK3vxREqll6XiGPJJ6Z9HWGNkee0Z5D1eNILI6R0Zo2Di2D0ZqddFQcAJrQ35V1Cf9kbughJPBprZQjXCq4+g+yD0QNZoADjESCSU1nBRsXZZ6zyP0wfN8Y5HQKBGUsgqAjp08+xQFu62JaIBIaDTDSGhdXR1WrlyJmTNnNtlvFGPHjkVVVRVdp6qqKk37Auy2r1yyZEmz+9m6dSsikQi6dOmStvyWW27BjTfeiP333x/nnHMOLr/8cuSxax0S9qqa0DCRTKa/sFkqkhEB9SXPyGW+SNZY2p5F/Vg63kaVz+AlcEyYZQM235h4nhhUcsnIKgOLcKqvWvYxkaAklER5BYKpklD1w6m2joi1SM2yDdilVS93kqbP2fkkx0GOl5ULsPR+HUlPxiLpHwTxGHFRMP7fCidqLN2rKa3ZPmzqEGmkUY2+EVBhkkhg1SiiN21PST0j5oRMqGIoBvUjQS0zYLXItAyAzDnG0uwkRR8hEVPDIqEs6knueeNN27N0CHESYen4Ig+BTZDyo2wjGg1YE4rdouimaE6vsmnTJiQSCWpHuXr1arqP5uwrq6ur6fhdu3bh6quvxtlnn43i4uLU8ssuuwyHHXYYunbtihUrVmDmzJn45JNPcMcdd7R4nEGRsyTUwcHBwcHBwSEbKCsrS/v37NmzMWfOnKzPo76+HmeeeSaMMbj//vvT/lvTaOqQIUNQUFCAiy66CHPnzqWEOQzkLAnt2Kke+R32bDzLUpYsYsoiTSyKGiWpPlZ3umMnSYew+ZGaQ1Y7ymoimV8lVdt78rGsrlWtOWVpcbYui1yyaK4N2PGzaDMtW2AeoyxiSlzDmJ0Xq89k6nWv6IhFM+X6T1JgVUIU+WIgg45jkUZWE8qWscgls1RSbZaY2l7dRwGrvRZqO5l1DosCxkSLJl5jSu5jMQLLo2+khpNZKomp9zxWx0pKIxjYcTAoEVM1Vc5ASxnEY2D7pb6wYuSbj2PKPtFUndXn2pRVeKKoEVa+JdaJehXzBZkWY7YCorHg6fiPPvooLerYHKnr1q0bYrEYtaNsrlazOftK7/hGAvrhhx/ixRdfTJsPw6hRo9DQ0IB169bh4IMP3uPYoMhZElpfH4MhJLMpGEFk2LWTeDqSFxzbHiOrLN/L0vuG/ChVn0yWVqeiGQ9JVOs/2bZUwikZ5IOTRlWlrtbd0pmIxDRZT+okRfU6TZd7BDJxQlSZsJVtn5FcVq/JnvtsH0ly2lk9qbqMganemTuw+r3C/HhZnWh+lKmU0/9NU7uE5DCLJqXmtLntqSIfbjNF9sHqEJlYh4CbyzMVebglPUFhQ0xV8DIA1pxAI5zq9qjqnXX7YfWfhpjVs6JNavCbvixCVIwRUieaV+Cv//Wm6PNEIWFrIhrRdVupdf4z7eLi4hZJH7C7xefw4cNRWVmJ8ePHA9jdbryyshJTp06l64wePRqVlZVpLcy99pWNBPT999/HSy+9hH333bfFubz99tuIRqNUkR8WcpaEeqFGOFViyuxzKNkQiRmDjWJeIZxsXJyQYVY3yWAjTGIRWJUQs3F
JkaxageyX3RcsKsngJaaqII5F4Nn9TnqsoEDdh0j8bIhpXCaNjPiQ3zKrnyXHkScSUwUNRjNIV309mShFJVdcJMWifpoSnNeskvpHC5W/IrqitbhiNycGtXOT6jYgi9jILaDWycr+n0xFz4gpNfckc/HUjkZImiNC3hf5JIMT8zhuRPcgYs4WorEIohlm5TKNnAK70+KTJk3CiBEjMHLkSMyfPx/bt29PqeUnTpyIPn36YO7cuQCAadOm4dhjj8Xtt9+OU045BY8//jjefPNNLFy4EMBuAnrGGWdg1apVePbZZ5FIJFL1ol27dkVBQQGqqqrw2muv4fjjj0fnzp1RVVWFyy+/HN///vexzz77ZHwMKnKWhBaX1CK/Q+aR0FoWuSSQU/QsYkqIChWcECIRYb6J9f7LXKuq/OvSl23rQoz/xcglNci3iKyq46iBv2iXxdgVSx11iGvRHW6hRKJPzFLIM4613qRzI52QWApcTp+zoAiZL1uXWj6J6XimhGfLaDqeEFg9ld+y/ZLq4clIhOoJapOOV43U1cgq89NkYDZLsrJcJMleqGUQMgllBNmihaq6X/UesGrTykztxXuAwhsmZF6ihIRG2O+uIH0Z+w1nG7Ho7r+M1gmwn7POOgufffYZrrvuOlRXV2PYsGFYtmxZSny0fv16RJuc2yOPPBKLFi3Ctddei2uuuQYDBgzAkiVLMGjQIADAhg0b8MwzzwAAhg0blravl156Cccddxzi8Tgef/xxzJkzB7W1tejfvz8uv/xyn+o+bETMnjySvoKoqalBSUkJTvvVb5DfocMex6qE04Y00mirGkVlER+R/LIonRIxZYRTjbSqHqZqHaYNaIMBMTrKSCMDrXMifdypbx4jSJ5xjIQydCr2R9/ixNeTdUJipJGU06IzMbVn67KOSSUF5F4kLxxew6l1PWIkmXVMUjs1eYmpSvLUVowMalcmdRzfh7Y9Bpt2nCzlz8YpEVObjlE06imSMrWjFfUEFf1j2YeNep9Rs3qyzOzc6h9Xu82/bOsW/7Ka9HHms82+IYmP/NtXesx/UV+PQ5a8gK1bt0pp7TDRyBt+12ssOtBSh+axI9mAMz5pm3nvDWh7zwMHBwcHBwcHB4ecQ86m40v2qUVBxy+/Fm0ikg0N2le2Wk9qk/JXfUxpNyhhXRpVJV/2CTIuxsQw9cE9CFkEUT3HrE6yME7q10hkTK3FZCIkbzQTaKbzEYkseverdj3KJ8fKop4sHc+U8Kw7khr1ZOXDatSTBcNZSp2N49sLFvXcvW5sj/8G9KiVGrm06b5jkyq26TtPxVm0NSa5p1iNKVouXWCKfDY37tdJ6l9JjZMaHWVQo576dSQ/KqYeZGDpeFUdzwzsvaAdmUTNg+c5E6TzUNiIxDIXJln0RMgJ5CwJrauL8TdsE7D0OSOcKqGrF1PqjKjsJAp8Nb2vCqwYMfOSTrZPBkaG1LIAm65ClDSS41LdC1jqXTV/Vz8I+EcHS9tLm5PAynhpvSZZxkRDNuPUXu+McKpG92wfrJ6U9ZNnD0lfFZP4omHVT8zuJ0qV0VqtJ0tjM+Jno9xmJExW71tYACniH2o4L86NCY4Y9LpORvTJh5NYYyp3YGLnmBFTJlZKiGp7hejS1p7k3dOJKObj6WVEeUzwkGVEo8hcmNRKc/mqIGdJqElG0ggVJWCkgwz7CmLkkkElKiwCyyOQGihpUtsdel78tNaTHBebLyOmyZ1sp9LUZNW3CpWsqzWhah2vSpwVsOOPksilasekgrXeZIIjtl9W18nIJeuOxJSn9B1BfqLMpY/tl5bNe16IXFxEbMpEAsYU3mr3HQa1ExAnZsFvDJXAqlA9UH1jxF7vNpFgul/x+JkDARMIqQIudq+wjklMHR9hgjX/mlwdX+j5Ve3015xGWFE5Uft4hUkxkfi3JnaT0AzXaZ2pfGWQsyTUix3bSKRRjGQxv0ZGJBlZtWmzqJJa9Th45Dd9zirhVI35wVSRIrl
UCSclZmy/Aa2SgGaslwjpZmCRb05+08epRB/E/oSl3hl5s/H1ZOMYt64lkZc48QRk6XNGnJngiIHvl0TMqB2PJx0ve2n6t5VQyYYYudSV5iF7UxKoIiGVrCnjWLpbJZxhm9qrx6/cY81tT/2woVDXZRFTtswb+WQpe2ZgT75Yvc+yPBvVfkiIRiOIZvi1zrIaDl8iZ0lol5I6FDQpjNtZFLyWUDWrV3vRUzNwlXCJJDRhkcr3Qq3DZFAjwTYEUY0qqqTe27kok+1xJbw2TqkJZTWXrP6T2Syx+66Q1c6K9k5sv2x+qlUSS58zNTtTwjNSqxJYRX3M+qvT7ki0IxG5J4hiXu1mZNM7nkEmxORly+ySVLBfPKsT9e6WEzXiXqDWpopEjV3HWCCDnuahdr6iYJ6gjNix1Dsbx5YJKXrWMYl2VspPXxYRgymtiUAWTW0fwG3XcJFiBwcHBwcHBweHrCNnI6G1DRGYJiIjNVrGop4sSqfWcKptFlVFtprypylgwQ3ARpRDa0LJOOalKQuuyDGwiKEaWZa7EonjVCjbo1Fqsh4VIZEMI9PpMaGOmu5WxU+scxGLXKpoYHNmvwFRWc/bakY9Y7SoJ6vhVCOcqhjGxvzeJqVM55IUOyaJEV22PWVuqlm96higCphU2LRLZaD3LItwRkR1PF2XjMvzjGMFlKSWlLXyjMTTt9X2cdD/dExy6fhQkbMkNBo1Lb7owyacTFlPHSxEQsMIFyOXjJiydDyDV4TD9smg1oTa2CwxgZCNGCgbUIvaFUEUM5y3AXu2qm07xc6tFCylzttxsrU16yVGpplinqEor+Vxat0gJ2DMAoiQIVEdrLaPtLFoUgmxSjjDFkRJ+6REUhQwqfWq4n5VwikLk9RXu9rKk4ma2LKoR2nKX3C+RbTHvOcH3x7S8YGESS4dv0fkLAntHgfiTTrKMdUuCkgEjdxQO0nnGtV5aGdtuBURdXViBJaQUOYG4K1XVIl0HTkutV6TETBGpJkgTI3Asg8HtQMR+xChEV1RRc/GMW9P7/1TJAZKSJBB9vpU6zrVOlFW/8minmwffByp/xRrQtnvsYC1aSX1mZpPKInu0G45rOZUezSr9Z/qOPaaV0mo2vWImheQc0zFOoLfKbWKouRaO8cqkVSjuewDQ+58RZdpynpKLhv83dSod2iY7T1D7n6XTUSjAXrHk+vt8CVyloQmTLp6l5FGVQXMXt7UYoYROFVFTqCq0hm4d6a0qrZ9kXDu3OG/BVk0t6iD/yFoU7ag9mtnUIVTcqkB+SBgScfOHssSqj5n96LqWy3+BthpsrB7RZI+pJltE0sVa/ZObF0WMY0QyyfuO+ldT/N0ZNtX0v1AMzZDKhEQxUXM2oftN0lS/obZZRFiysoF2LpqxNRrKWRjs6RGeFXIvp7iXPixEbCHeVJNs4speiEdHykq9A0x23b4N0W+YqOd0z9MovltT+YiMYD0rtjzOq0zla8McpaE1hsg2uR5xiKhrKaNRWjYi7qWRBXVekoG1etTNWFXbZu8EU1uaK+RQUa2VHsr9Rho3alYeqBp3nWyzoipSmBZJNR7j6o2S2qtJ4cWMbUBJ43auiyVz44tQYge++lRM3lGTFr5zcJN3tlIrdaTkje5/lH0GKUETiPJqi8q24NCHGnHJFJfqlpP0X2oRFdM5dvsQ+bSNpGHfD/BpP3pvWD+ogS+dDwz0c8yXCQ0fOQsCY1G0l/Y7MXKolGMcLIXZryA+QGSdUVxjez1KY5Tya+3Sw+Lqlm1zwxZ+EM9QWWLJm2/ai2qSjjZvcLgJZ1qHSa779hrQK3/tCGhrA6TpdQ5ufRvj9WOFpHUuxoJZaDRS89bXo16xqBZFlFiJfqJMtCop2grxes6/ZBrTGmU14+gPp6quMquXjV47azcBlU+Dv+67PpEY+TeS5B0PKvZjJF7ha3rvbbMJ5R1USoipSueOtEIaxqRZUQDRELbnjq3b7jz4+Dg4ODg4ODgkHXkbCS
0a9ygsPDLLysW3alj/d9Z2p5kq/ROM/6Fu/K1xDDbHkvRspaKNDrGWpJ6jleNtKqR0Gx0jFKN7hnUkgcU+RcxUQ9tW0nT5S1vjxnOs6giQzEJhaqlJuy4WDSGiYGY4TwTK1GRVL7/RmYiJDWyykRIQQVGiqF9c2CRQSaGYmDRTDXqyTsXaXOmJvQseEsjoWLojgzjaXstveuFKppS62npqmSZWhpBtyd2yOJCIhK5ZMIkku4yCZIXpCUjwnOf1qu2fZRTQSwSQSzDdDyzh3P4EjlLQgui6YRtp1g3qKbtWR1evdbmmSqNGZFkc2Hj1H1s20UecEIqW63XtEm9q4RTbqkpColox6RC/0lmZE3tp86gpNrVZ6HNfaISUwZVSMTAUuXq9tj84jT1blPD17I6nkFVQTNY1XrSNLtYm8dII4FMOOlOyPZYHW9A70zVZsnGw1QlpgzqPmhPeAYmOGLqeNGiKZIX9y0zDaT+00swWeo9Rs4xS9t7Hz6iPWBrIpBFU9tPu10jZ0loUKj9tevI7119obNoq0pM1X3sqCciIVKbWC+Qc9WKiKr5LUzU1Ain2sOdjVPrH9n1UcmaWtvp/bBhwQN1vuo+eWtL/zKVELOIpKqOZ6gn6+aRdZlYKUp9NzUS4t++KKwRCC3Ao3SMlFE1u9iHXYVaT6kq/3mfdM2fUx2nQBUDqedOFRzZiJ944wC2D/Kbsukxz8CIrm+MuE8iVvKp48PtgBoIgczqnTBpj8hZEhrUoommsdlHnDoPkazaZCvYb4aprxm8hug2VlY2fd2Z4Ej15mSRWlU4xsiVGllkxFQmicK6LO3MyBYTHNF7gtx3bJx6nhjhVMG2pwqJ1J7wNl2EFDACxjbFiEW+GKXkQhqNNFLrKQJGdK3M+dXnBV235Yih6hPKYGPRZEMuGYFl1JeVaXCyTh4gavpcPQ61s5IXoljJ9xBoB/6iLhIaPnKWhO5KRAAxBd8U6kt5FyGXrF7TBmx7rOaQzY8RJLaul3RycuBfRokaeVgyIskiptTXkxBY1cpKPSdq60n1vlDBXAi8kQCVcNo8u9Vrq9ai8n2w6Kh/HIuYxqNalIq3H2X7EL0zPS/gsGv65E475Dyxek0WqeU1oZrfqex1mQVFs3cf6nx1xbyGoB8rAI+GU1sp1pKVRr7ZRxeBDTENWv/J4G33CSDi+YF6/90WiEQNIhk+6zIdn2twHN3BwcHBwcHBwSHryNlIaFlHg6KOX36hbFedygnUtD1t+UmisSx1yCM52jg1mredCKd8KWAWLRTV99xcnYUktXapNrBJPSup8uaWMfBoo3/H3kituk8W4VVLBVhKnSncmRKeoQNpF8r20SmPib+ISIwZ/dPyAy0CyaKIihhEVrOLaXY2D5ayZdtjEVN1fpEkS7Oz+gtpczzaKtcAs3XZXNInk0e8WFlkOY8UGbLoKL8Wano/uMcoAxOT0fpPlkoJqmZvdjJCdJSmzli9GVPkmz3+uy0QiWZ+ysIuxf2qIWdJaF0SaPreZKREVRCr9YCMSLH0KbOGYlAJp9rekZErLxSi2hzY80gll2paWCW/dB+MJJNx6n7VlD+7B5Rrxm3F/Mu6aPyDgu2DEc5a0s2kkygIY6SRCY6YHZMNVKWxQhpYfSVTcqtiG9bNh6eKRdGMOI7WNapvUbmFqLg9puZm63oJDBUDqQbxWmpbJZd8H+Ga2lNQokfGsWumLmOCBi+Y+l42sE9/cIkl4a2KSMQgkuFEMh2fa8hZEhqLpL9gVQsk9pLvTIIbO8jvkxEptj0WBWNgJJRG+ESCpNSYMm9KBloTSs4J254aWVWhRjPVHuvEYURuoalCiXKy7at1rWycOl/WpYgRSRblDjuYYaOsZ8QxxpTGhAwoNZGcWIjtGVmdn7g9WjvK5stOkxq1YQREJbAqyVFELoAvYEoV5Ca4NRZX/fvHqR8dDKr1FN0eOzZGutn5VLoeNbcuHResTjSSTyL6XmLKHrxZhouEho+
cJaF50ZZtamxemGrUT43S8aiaf6FKpllSkAtd0v/NSLONoIeBEcRCNVVOtqcKuBghVq8jN/9vOaUOBD9XNnX6qsKdEUnan160XmJRVHY+80KOmrOXskoIpX0SEmHTD1y1+5EjsEyoIloqUbLGSInY796wNDu7Fmok1LNMJ3TBzwkXjgWPhstkVRSE0Q8MG6LPTO2VdDwlqqzuhz18van9tmdzkajJ2FbQCZP2jJwloTsagKaCUUZU1NQ2IzQMSkei5vbBoPpE2mzPSy6i5CGtzoNlZtToKI8Y+5cxqOUINh8d7Nqye4qdK9Vj1LsPGw9PPk4jkqpBvApWE6qq7bnvqH8cqwnlL361n7hHkR0hBITMw1BvUv/c1FIBNZonR9AoaxRrCWmEk+xD9mhi0beWV6PEnKnP6b0tOgFQ9y2WZg/uz0rLNFSXA+odKl4zas3Btkd+K95rRtPx4te+1zuUvSyzjEgkQCS07UX97Rpt/2nh4ODg4ODg4OCQc8jZSKi3JlQ172ZpYVXApPZrZx/PNqIeXbndstpeFdGoSIjhTDWNrXYCUqEavdsIkxjYPryRVbWGU28VqqXeeTrev4yl3un2yFxoip5eC23OappV1YJ4I1wxosi2Mi8XhSpqSYGcxuVr+xexVDmNqomtIhnYc4WpqD3XVjW5Z/cES5/T6KhF/afuTxq8xzwFS2c3MPECu94W9b6+bWmR1vaajnc+oeEiZ0loPKa/nJtCJZJqnR9LM6vbUwmCmvJn6eOdAsHWSZl/YHE+cwfQ1mU1sQyqvZX6IcLLFsIdp4jY2PWSW4BapMBZ2p7dY3EL3sOumd6Ok6XotRe/3APeQ8w4afRDrVdUjeRl1bvac92mllCt4VT3YVFT6xsTck7UJs2uNidg0FPv7KEsXh82TnU+8N4XbPtKdyS2rB10THLCpPCRsyTUmGDekyqRVDsSdWTqcEZgQxYGqsfe0ePryKK+6jGwt4967naJ4VYeQVTJlTRMvrZ6BNK/TCGYKuFU212qZJ3tl+1DbSvKwL0+pVX5S5m16BSjXqw+00sk1AgV63iTF41L88hTuxSpPX5YRMrmjRk0Mma9rnc9bTVa/kqFWdo5UV0JWF2nlTpehaqOZxFHw6gCWdcLFvVWI5rtMBLq2naGj5wloV0KDDrEv/xBU69C0a+TiTfYuoz4qeQl7FaRNn6iXqjkmkGd7z7E67KtBEfs3DEvTrWOntbkswCCQKbZfczM5Rlp5J6g2rr8GNg+tBQ9M6EvIAPVNpjM1J2P89+4vF93+rFxGyfSilC0e4ow4qtGM22gRjMbdgXfh0o4ZX9SzziLVpRyrI0JdYxWBkBFZ+IHAY1os0mrfnMMLBLK1PGK1ZY6DyZWaodwPqHhY++48q2Awrxk2ouOeR+qYOSAvaiZoXdnko7m3ZbUSIt/GSWSLNImbI8RK9UjU42+qqRRJYjqugyU/ItcgH1McDW3llL2EkJGOGtJBy6VcKpQ0/YsimqjhKeRJnLX8kgoS2NqdaJMbe3rHR9yzo1G36zIhkWkUU6zi9tjJIdG30Qy6Z1f2BFZ8drKhvNqTaRatiCPy0IXJS/Ue4fA6x0aIR+m2YZLx4cPd3ocHBwcHBwcHByyjpyNhCZNJC0CZVOrxsUR/nXVDjLcm5Ftz7/MJsLF4N1eASsrEpXhqhMAS+Wr+1DFRUFLD5rbnhrlZRF3NeWt7FOtpVQN5xnUe5tB7XDEoqhytyGxLaINgqqUrRTzVOGt+nWKUSSbddV+5axOUBXNxAJGKlVxlcWx8mtBiujpfsk+6C0m3j9KxBjIQHDEHqJknPfaUq9XsZ2g9xyrc21FOHV8+HCR0BCQNP6/2kTE9xePGfEv6fvLixrfXz75Y+PYH9tvfhQt/rFtNdpdNf0rjPn/1HEMbC5se1HyZzM/9qeCnWN2/qIR9uc/DmX7bD3+19x+0//Yvch/AxHfH5+fdqzsLxKJ+P4YIpGo/4/8z5D/sX2w7Slg+wx
9nEn6/xjYuGSD/4+OI3+Necmmf1Hyl1fg/2PrqttLNPj/lGNVzwmbBztP0Tz/H0Msz/+nzsWA/FkcGwM7XnVddg689wm7hupfQX76H2ntmW2ot673Lwjuu+8+9OvXD4WFhRg1ahRef/31PY5fvHgxBg4ciMLCQgwePBjPPfdc6r/V19fj6quvxuDBg9GxY0f07t0bEydOxMcff5y2jc2bN+Pcc89FcXExunTpggsuuADbtm0LdgAicjYS2kjaGhG2nU6nfP8Pd2cDE1FoNYIq2Lqq6IpHedP/zSJ5rM5PrbFlEckOef6FrNZRtZnS6zD966r1lHodq401Usv75O0z2ThplzTK3zHPHwJh95NNtFn9nanWO4w42ngu+mpCxQ41Nl6fofdcVyOB6kuU1qda1CHKtZOecartEOteZRMxVY+BwaZw0KLukoIR5aARU7XnvLStto+Z7ebHmUU2g0z7iSeeQHl5OSoqKjBq1CjMnz8f48aNw5o1a9CjRw/f+BUrVuDss8/G3Llzceqpp2LRokUYP348Vq1ahUGDBmHHjh1YtWoVZs2ahaFDh+Lf//43pk2bhu985zt48803U9s599xz8cknn2D58uWor6/H5MmTceGFF2LRokWZH4SIiKGOzV9d1NTUoKSkBM+tuR8dOxftcaz6ImRiGCZCYilG1RaHgb34WWqXETgGxaKIbUslUaqLgLo9lfipLSq50Md/HVXSre5DJcnKeeGG7v5xLKLJ7kWqto9qCndFXLV7fmwZI3Xst+KPjjD1MSOE1HqJqI+5ej19e2weqtdn6N6PKhh5UYkf3Z5Fyl+FYnSvklDVNJ/Og9hhZMOOymZ76jLVJ5TcP6Zue/qCeuKiwJbV7vAv25a+rGZbLboceQu2bt2K4uJi//hWRCNveP/Mseicn9l980V9Awb89oWM5j1q1CgcfvjhuPfeewEAyWQSZWVluPTSSzFjxgzf+LPOOgvbt2/Hs88+m1p2xBFHYNiwYaioqKD7eOONNzBy5Eh8+OGH2H///fHee+/h0EMPxRtvvIERI0YAAJYtW4aTTz4Z//rXv9C7d++MjluFi4T+B4zQMIKk2jZ1IEo+tg/GK1RlPbOxUSOQvJ605YipSnLVOke19zclOWI9rVqvyKCSNXWcSmDZOO8yNSKrWDvt3h7Zp3zNNAIbNuFUCCKgRz255VPLj0lDXtKst7aNWb1cNxl2ZIzZ89jUU6pzYcehwKb+VbaFCrk+0WYu6nmyqf+sJ/eAzQeQAq9tU14r709ABAEsmv5T8FtTU5O2PB6PIx73ewTX1dVh5cqVmDlzZmpZNBrF2LFjUVVVRfdRVVWF8vLytGXjxo3DkiVLmp3X1q1bEYlE0KVLl9Q2unTpkiKgADB27FhEo1G89tpr+O53v7vH4wyKnCWh9clIi4SSpyw1csVS73wfGqnjaVbtxapGFhm8ZJJHt7RtsXPC09P+ddWoL+9UFVxww8dp28sjEUNul6WdAy/UdDxfVyPS6rmzucfUlDpvY6hFM2065vD9etLxol+pamjOhSpqDUXIUT+1OxJDHokYqgImFUE9MW2IuU0XKZu52JBkmxKCoOdKjRizDx3FDD/LsLFoKisrS1s+e/ZszJkzxzd+06ZNSCQS6NmzZ9rynj17YvXq1XQf1dXVdHx1dTUdv2vXLlx99dU4++yzU9HZ6upqX6o/Ly8PXbt2bXY7YSBnSaiDg4ODg4ODQzbw0UcfpaXjWRQ0G6ivr8eZZ54JYwzuv//+NplDU+QsCVXS8TYiHL1Von+ZjaG5akfEwI7DG+FiUVq1DjHsekgWHVUN0lWBjE23JdUaKWg9ZZC2s41g51P9wLepsaXiItZS06KvOzOXj5FHnSoIYpFVr+UTS8ezjLoe9STL2Di2Xxp9FCNyqiiFBR/VqFrY9Y9BW1my6BttMym+Jm3srdTyBrE2k26PLaMRSAJ6DvzrRqLpJTOmYbtvTOjtYrMIG4um4uJiqSa0W7d
uiMVi2LhxY9ryjRs3orS0lK5TWloqjW8koB9++CFefPHFtPmUlpbi008/TRvf0NCAzZs3N7vfMJCzJLTAYzfDCY1/PTUtzGAjpGFpZjktTJapRM9LdFRCp5Itdu4YuWLH34l0m2Lj+D40Uh8NucZUreNU0tsqyaPHJW9PI/VsXVVcQ5eJHY5oGpyl4226KNESnGAtNLnXJ7lo7PKzPt8WHX4oWCpfJTkqQWrt+annhJH1sDsNqctsCKxNip6dA0ZM1XtPuRYq8mJ7/ncboNE5LNN1MkFBQQGGDx+OyspKjB8/HsBuYVJlZSWmTp1K1xk9ejQqKysxffr01LLly5dj9OjRqX83EtD3338fL730Evbdd1/fNrZs2YKVK1di+PDhAIAXX3wRyWQSo0aNyuwgMkDOktCEiaTVxdnUZqo2SyppsDGm14U+/oXb61tWgsdFYq7WjqrG/CzCq+6DRz2Dq8htIpAMqnG8d5xKhvPFXu90XVFcxKC22aTEjDzpVeIX9jhF1KSTV5Fw2kAlSDZRNRtyaUPqlHHqemr7UDaOkTcV6rmzOZ+qMT2tbWXnQJuKdC3UFq3ece0gWpots/ry8nJMmjQJI0aMwMiRIzF//nxs374dkydPBgBMnDgRffr0wdy5cwEA06ZNw7HHHovbb78dp5xyCh5//HG8+eabWLhwIYDdBPSMM87AqlWr8OyzzyKRSKTqPLt27YqCggIccsgh+Pa3v40pU6agoqIC9fX1mDp1KiZMmNBqynggh0moNx3P0DHf/8tTbYYYeWGkViWcXG2vEV3Voqkj8Tb1pugZyeXHqllU8ciqlnpXiRRDXC6XUP1Ubayr/MuCHltc9hzVopl8Xf8yNerJbJFUJEn3GaaYZ+SXjWNQjegVAkuV8GpKnW7QIp2qwsYTM2xxkSFEKr+QrOsZp6bP5XlkIcJL5yJ+EDCoNlWMELJ7KswyCDUqXV8fbPutiGz1jj/rrLPw2Wef4brrrkN1dTWGDRuGZcuWpcRH69evR7TJtTvyyCOxaNEiXHvttbjmmmswYMAALFmyBIMGDQIAbNiwAc888wwAYNiwYWn7eumll3DccccBAB577DFMnToVJ554IqLRKE4//XTcfffdmR9ABmhzn9D77rsP8+bNQ3V1NYYOHYp77rkHI0eObHb8/Pnzcf/992P9+vXo1q0bzjjjDMydOxeFheThRNDo9/WntfegU/GefUJZjaSqKlYjfIyEMQLHEDbJUeynbGozbaCSsjAJXXPbY1BtkPg+/BMsYsTec55ZhJOBl0H4t6+S6wKyQTWqqJJQtSaULqPq+OCR1byI/wXpN6sXz0nYXpqq16esjBYjaDYklM1PbeXJSJNXgW2jSGfzUL1T1XFW5vfisVnVmIrjmN+nd7VErbYeI747vkj7Z80Xu9DlG3Pa1Cd0/eTjUMx6V+9p3boG7P/wy20y770BbRrfbuwKMHv2bKxatQpDhw7FuHHjfMWxjVi0aBFmzJiB2bNn47333sMvf/lLPPHEE7jmmmuyPHMHBwcHBwcHBwcbtGk6/o477sCUKVNSdQ4VFRVYunQpHnroIdoVYMWKFTjqqKNwzjnnAAD69euHs88+G6+99lrG+44inYHbWA/bpOhVr0/d61KT2rIIn9rRSVnPptZVjT6qtbhsHItyFzDnA3LumIG7Ch5J9o9j58Cbalcj4TZRWhb11MVFWitLpixX0+yqwp1GKgXVOxByy0+biBetzQy5rlONjNHsrMVTVPU2VVPjXqg1lywip9Z/qtHcJDP/V1tZWryy5RpTi2g4EzAFRTsUJjW2tc90HYfm0WYkNEhXgCOPPBK/+c1v8Prrr2PkyJH45z//ieeeew7nnXdexvtPIp14svuEkZI68TlL052s1pGsW0TU3A2s+w7tPkOIJJug2DFJIqFiGpeNU1HPyCUZx9LYDPnkHNMORxa1kyrUfvJeFT27x9jHDyO0hUSFparAZcIpErVYlLS8JPtgoiY
Gta5TPl6hraZshh9UbAM007mIHSsjQxZdj1SLJjUdz0gdI2vseAs6+Jd507s0pS4Sc/m8k3VjItuwaRxgU3ebjXW9pRENO8PdfhsjW8KkXEKbkdAgXQHOOeccbNq0CWPGjIExBg0NDfjRj360x3R8bW0tamu/rEtpbJ3VMS+JTk3EPixaxvuaB7fsYdE3WvvHWnSKKm0WWWSde9SIoSLgUduRMihkK5PtqftQo6iq8p9BdS9QfWY75KWPY+eJt0tV6zX95ICVjKutJ9WWmlatLMX9ylFPG5W7Fzb1n2pLRLWVJ1sm1PQ1uy5VUIdc7xo00qbWkjLInpui8IfBZn429aQMNuQv6DloByr3oMiWMCmXsFep419++WXcfPPNWLBgAUaNGoV//OMfmDZtGm688UbMmjWLrjN37lxcf/31vuVb62NI1H35YmJRMN0qKbiaW089a/tVBUfqul4yxIiaGpGtTfp/jbXkBLCIMbs+LDocJ4R7h9hClYERWFWwxsCilwzsWuzwioDF6Fs85o+Fs8hlg/G/4FlaPEHSdUxwxMibIQp3lmZnKnI2FxqVNSTlz1L55O3AFPhcF5w+Ti094DxajMhRRTZZl5FLVZHMYNNSUo38siiiSq6866rRTLZPqr5Xo4Ahv05t2IuNDZR6HGwfnnUj5Boa8oyWzPVtShHCQjSidzppuo5Ds2izqxqkK8CsWbNw3nnn4Yc//CEAYPDgwdi+fTsuvPBC/PSnP02zLGjEzJkzUV5envp3TU0NysrKEI147w2WEvXPgZFGrpi3qeFk+9VuZEYIoyw6xvYhbN/GqJ0RRJZmp6l8MS3OPUaDuwgw8GurugZo+2DwRlHtanG1HuZ6X3d2jwVP+fPaTK0MgNaYRrUWeepcvMvky6rWcKqgUbUsqK9V2GxPrUX1HofqfUnnJqbyY4Ss2kQ4WY91tWbXZnsManQ9KNTr4/1IylPNSlsPkWgEkQxJZabjcw1tFihu2hWgEY1dAZq6/DfFjh07fEQzFtv90mjOaSoej6faZaltsxwcHBwcHBwcHFoXbRrfzrQrwGmnnYY77rgD3/jGN1Lp+FmzZuG0005LkVEVUZi0qB77WFFrM2MkqsQifKqvo+pFyqDWWDKodaIKWMSPmeaz88kjnCxi6t+HTYtOXgMcPPLbiZj/q1FexZ8zweo1SSQvxtTdYoRTrte08OHkveODi5/yLKKeXJjU8rd6hH0E0+5INt2MQhZvqJExWTGv1l2K69JzJbyyVKU9g805UaHWurLzpJZQ2Pi9hukBG+Y+24N4KRbVBWhN13FoFm1KQjPtCnDttdciEong2muvxYYNG9C9e3ecdtppuOmmmzLed1EsmUaAWJ1fPgn/q5F1tSOP2m1IJYO817d/HE89ByNcar1qnnjyOCkj+5VJqEbyWFcqtTtQgjwgVcLJSGJ+VOjIo5I82rmItc9U61U1wRGdCyWcoml8lAmnSM0ZKSzJj2rNLFQRkrd2lIqmbPqBM4TZ2rK5ZTYpZZs0rqpAV86fOl+b86kSXfrhoJrai3Oh6e2A5665cVScxo7Xcx3DbIjQHhQ+/jo+bR2HZtHmHZOyjcbOB29+eHdaxyTVw5NG8wh5YbARJjGoHW4YwVbtfbz7CNv/M4+QLSYOYaD1gBZeknQfhIDQXueqRQ/bh9jy0rtflbwx2PRNVz08Wachr6Bn9z40YkqtkqiHp3+ZSmBV8ZMPqjKaQW29adMJKOxuObLgSPTOVAlGmKaLYQuuVKjdltj5ZMKpbFgvsXtUuX/IPWuSxJiQienqdqT9s+aLXehyyNVt2jGp+oqxKI5rVnGpdWvrUXr7C65jUjNoB3KztkFBLJlGxBRlOGDX65yLa1qa6W5wc3n/yiwNzsoFeCtHZlnT8gSVqN3ufZKoFSF0jDAwcDJIVNoiaWSpZ6oEJ8er9itnBJutS9PgnimrRJJF93SCyEoZ2PGLBFEsIVCPjZ0n6jvKIpVqmULQNLiqUA4
7QqWSWhtCZ+VtauFPSpwPpHlQ2JQFqJFBUdTE5qL6napQiS7bLzs2Rhy92wuT6Ft86IcGFwkNHTlLQpMmkhbp5NE8TX2s2vMwKP3amwMjdYw0MrLKECMPGiUqKUcByQskyjryiBFOFap9ECNh6rqqSpt7XWr2Pl5ixlPRwUke9chkBJ6q3jVw4/fg9Z823Yyodyg9XgIvGVBftmoUUCUb1HpITAurBFH1zrTpqqNY9Kiw8SFV+9rblFqoHwksAGBTF6x+FKlklcG7PZsOSu3RoikWwKLJxhIlB9AOrmrbIO6JhDKo5E0lSOzFqkYR6fay8GXonbNNmtSGSEYIYYiIkRxGJBkY4aa+kRaRO7XuksFLmiLkoUzN5cMWCLEIvEgamVWSHPVkEViRiKv3gJXBvLT9kKOjDCrJYb/Ruu3+YXlE6CWSZOYLG2HX28bXUom+qR8JlHCLRI2142RCItkuy8LOi6XtbQz21YipZ5xpEG2V2oPoSEAkEsCiqT1EcNsx2kGlr4ODg4ODg4ODQ64hZyOh2+tjiNR/+UWuWhuxyDpXz4rm8lHtS5RaBTFXGFG5Hdy2R6ubVFO7cqcdVjdJTwlTx2u1jgw0Rc/OMSu2J5D7miuRAeoKFDw9zaFF6eh+aUciIlQIWUwWE22leK2wGJX0RqHlDkcWNXIMLBplofCmUU8LEQ6L1oeu5vbvNNh6gG6BpJ53tSORGoFX56eeO7Y9NToq3GeRmP9+Mola3zIp1R52R6ogcBZNoaMdXNW2gVeYxGozGZHU2y5qAiZGBinhIiksSq4IkSwQhT4MChlQ09NU+EJ7iQfv0kMrLsU0LjufBdEi3zK2X4j7YFDPlddmyK68QXsw2pQeyK0sxf2qZR/0PqNesSwtLKZAvftlIg2W/lTJqo1VkkU6PvSOSTaiq9ZO0dLti2lxlVcEFVc1B1ZWkg11vLquQKZZOYYh+ob26BPqOiaFj5wloV6oLSAZkWRgYiWuwNeUxjZ1cyokU26RIHJrH034FBXVzbw3uVify7w5iZqdrku3Rw1FJTCBTDIi9lj37ZIJ09S+7pqQyobAsn2oRJe3BtXIKoUafVJU33LdoEUrRgabqJUqGlFrGNX92hipE3gtf+RXvlqHST08xUioeqw2jgE2dcyq6IrNL6joqD34fQaFEyaFjpwlofGYSbMksklj0y41VNDiP93cKkjzYaSehipJlM3F07dnQw4Sxp+yLoj5I40MjKxSUksV7tr55PeAf84yCRPTwgzMQskLGpEVQT9WAnYL2j1OU71bpXbphwgrDSDHxk6VTS/tIGMAOxKhbs8GLDPDSk0S5HcRIx9x2TDnD2u95tal0VELgZlN6YbNRweDjUNCwPuW3k8M7dGs3pHQ0JGzJHRnQxR5DV/e1Nwqyf8y420cGeH0j4vH/D8+2plNNHBXjbWpYjogQVIjaDSSJ6bAeZkBI4haLapMJEXLK3aOuf+nRhLZ9hLkoa9EudVoc0Sdm+p/Sn4DNkp91ZWAQuXmMVG5rLz4bSKDrGZMJW+qbY1FfWqEZCbonFUrJwJWJ0j3S8DqDsnG/Mss5ivvw6Ytpg1xpqUbFjZi6nnxzoUeFgsKEGKabH8k1KXjw0fbX1UHBwcHBwcHB4ecQ85GQvOiJq37EeuExMAioTyNy0RI/i9ArrYn6X2xXpGnWdkiLZrn7T5DWx0S4ROPNPoXqWUGNsIsu7ad2rh89nVvUZ9Kz4sQGeJen6QMRKyxpfuw8AlV1fs2tc30MNTIoo25thc26f5s9J0PU5HeHCyETixtS0stfIMsop40ImkRbQ77HNsY8duI3bLhb9vSttqBMMml48NHzpLQ/KhJI5Rq/3MGWS0sEiSb+k9KaMQfL+2v7Xmjc9WylrLmZNCmh7l2+6p1rHZWRn4wY3YGKjgSRD2MXKotQFUrInq9RSKpklqrNLtKctQ0e2u+RG3XtaprFMcx83+ha1pzoEpo1cA+aImDLOgh159ZFrG6Sdn
8ntVcWpRuqGlx9d5OWNSsKvXNVLDnX0Tr0tpjTWgkmnm72/Yw73aMnCWhSZNOPMMmnGHb57CoHyONco9suVWi598iuWQEkdVwMrB1GeFqSJI6MrVFqShWoopxC6KvCqLYfZbw1E3x6LB/nzIJV+2TmJqf1UVT8qtF4GnNsk1bSAamDGaEI0y/SrV+j9n4WETBqBiEkUtVHc52KxJJ9qilxFTaK9kWi6CyLELQ+l9A96xURT7qunIfe/a7ENXxquCItlr1bks7VuZPa+pCtgsLAZFYBJEMI5uZjs815CwJjUbSiadKBNQWXPkRf9s0NVrEyKWqDld7bgdVVsueoKKwRC0z0OfCCKwqhvEvU1LggG3feY2YeRXz8j0herHSjx/q7UrS50b8rZAXXJS1rlVT6qrNDgUZR9XCwj5Ui6YsRCkZGVQjkrI6XISshGbrEgU+J+zpx8HdESzKERSLrubGMVKrElirkoyQ1fHqOfBeHxv7qPYYCfUSB3Udh2aRsyRUgRyhIaDvX9U+SPRrZPPjSv3wvEPVeajElO+T7EM0umdQOybZpOPVaCOPrGrE3nu8zC6MQa25tLrfxQ5Z9AwHNYgH9MiiSgjVdKzysrUhnKqqOsFcCSyIKTMNF4kE9WCXo2rk45lZPjEo6XibOlEb8kNJs0Vdp0qcVfJr89tTPEZVH1JlmasJ/UqiHXxaODg4ODg4ODg45BpyNhIajcRajBDy3tJqLSUT0pAaTrmGkaR2RR9GtQ5RiZjx2lRNgMOyEqq4ipkX6Ip5rdTCBmoU1RD/WLU+13sdeUciscyCdR4VBWYq+DkhO2Y1nDa1dGp0h0HtKuONNKkCDDXNbpHGViGnu8n5ZFFKWcBEswFke/U7/euSSLVyh8pKezUKaOMnyu5Pm4ihGh21EVgxKBFnscaW3jvtMB3vfELDR86S0KRJpBEWbvwefPs0FctewKIiWSfELF1O7KJIHZ4iwlF73fMyA41wMrDazAR5sbBxQUlec+B1nVo9pSrWYYgEbDCg1maq10Ku12T91PP9tdJ8HxadhVTCqab3lZe3TcckUbxDNycSP5kgisIk0+AXBVJQAk9IN9se/RAh63q5ilh6YGUaL4vkxLS4jRuClf2UjSG+xYeidxpSaUjbk1DEorxRQ0vrODSLnCWh9clIWpekSFTr1R0hhU82innWnpHXWLINEuJDOhWpJEft9e3fPvNJ1boUMYGMXktJ9kuIKbee0qLDDElybIa2X9Wi0uw6KvOTa1gZabSJKoRZM5bJXFRrFFVZbtPlKExY1LrJgiMbgQhbVSZIFmpuEd5zIFtAyTuwuD+ZIl2FSvJsakxbu85SPSfS3NpDTSgC1IS2yky+MshZEloQjaCghTA5tR4iD4agqe3mlqlkSO2dTqNvspF4y3OxIdcMdF3aFrJlc31AP5/c3kqbn5oGp0IfulsyToi20TvaJqKiqmIZqBDCoge1Oj81GlO3Q9ueOhcFqpWXmCrXlfDB23ZaRQdlMUxwguH75hCipYAebdavtWpqb6N6J1CjnjZtSoMSWLUEQLrH2j6iGIkESMfbpFRzADlLQg2SaeRJrS/kaXFtn4zkqVFPVX3NCJI6FzrO86jmHpkaoWWpfFUJz8ib3CSAEUQb8sKOQ35RBzf+9j2oVQU1Oy5ZPWthMaN6btqkDm3qRFUI+1BrOGnfbJKKVmsuaRpbvS9kI3HxHmgg49g9wBwI2P1oE20NCpte7+r21L7uNhF4dn1sygrUCGzQOs52UO8pwanjQ8decuUdHBwcHBwcHBy+SsjZSGgE0bQ0rao05+nz4O0zab2imNpVTejVWk/eBtLTtpMea/CiFxrNlAXeJCrNzp1NOlGty4qR6I5F32wJVim8kFWx8n6ZUEVUx6v7beWoJwPtyMMg+prSdLwKuR94yN6pajRLNUgP2h2IHReJtJL+CjxFb6OEV83/qWI+5NIQUSQml70weM+z4iwBaL8Lm2dWWHBm9aEjZ0m
o16JJVXirULsZqeSXEU5mjaSmt7kbQMvpcpsaVl1Io70caUmB+kBmxEeto2Ljarf5l7GXtwrlQa0SaZlci+fEJmWp1qWpdacqxPtCTqsLH16USNp8EKlQCWLYIhcG+cNB7OOu1JOqZQYiTIPfKioSI7Z0Nve2WqKgPlPUUhib8hhlnM3v2Hv8rf1RLyCbbTvvu+8+zJs3D9XV1Rg6dCjuuecejBw5stnxixcvxqxZs7Bu3ToMGDAAt956K04++eTUf3/qqadQUVGBlStXYvPmzXjrrbcwbNiwtG0cd9xx+L//+7+0ZRdddBEqKioCHYOCnCWhCVOPhPny8Bkpk4U0cptNTVnPldGsZtX/o2THQQmn2PnI+7INW+RDo542liNq5FLuES1GLRhsokVBO4ioURub/tU25IWBWTnZHFvYvqPsXskXPlBtyFDYav5Wtthpdh9hQ9kHDT6q0Ucxq2PTDUuFTeTPxpkizE5f4gcrI/rtEtFo5tclwHV84oknUF5ejoqKCowaNQrz58/HuHHjsGbNGvTo0cM3fsWKFTj77LMxd+5cnHrqqVi0aBHGjx+PVatWYdCgQQCA7du3Y8yYMTjzzDMxZcqUZvc9ZcoU3HDDDal/d+jQIeP5Z4KcJaFJ4+0VTjwnabSQtSdkUTrND1JteUmjlEzoxNpb0vaJmnrfC2a9xOcrPswZMQ3bQFklnAxqREGO2oSoUFUjjez42YMxaH9ogB8DU5+z86Sm7GxgQ67oSzNdEBTJ0xo2UITtTckg2wyF/EqwIb9hCo4sIJdaqOfYJkMStieoTTtTBqVtZ1B3jbDvzSDIUjr+jjvuwJQpUzB58mQAQEVFBZYuXYqHHnoIM2bM8I2/66678O1vfxtXXnklAODGG2/E8uXLce+996aimOeddx4AYN26dXvcd4cOHVBaWprxnIOiHVzV9gGbDjqyT6iF4TwD9d1UFePiPrxEXE6pqy8QFrVSyRAlfhZROrXWM+xUMTsHDMo5ZQRRJXks6su2xyKXcpTSwtTdgoSpZu1B7Y1oGt+GRGUj0qbChuiqhNNmnFITSqA6EMhNHOhObJ5l5DhYswcbhbtNaY2yruh6QF0jstA1LGNYREJramrSFsfjccTj/o/Xuro6rFy5EjNnzmyyiSjGjh2LqqoquouqqiqUl5enLRs3bhyWLFmS2VwBPPbYY/jNb36D0tJSnHbaaZg1a1arRkNzloR6zeqjzOSdiXyYATlpxaiSQZryZ96UFo63vDRAu/T+0gAxHR92JMdGbKGCdjwRXxjZIA3edVXBiIqw2/qpIiQGFn0UWy/KZu3yPSWQofYRtNuNkM+7VfTNpqyCQUirU3JJ6nOZdyiFTd2kCvX+VIlkNmqP1evthRgd9ZWCtQcbJwsSWlZWlrZ49uzZmDNnjm/4pk2bkEgk0LNnz7TlPXv2xOrVq+kuqqur6fjq6uqMpnrOOeegb9++6N27N/7yl7/g6quvxpo1a/DUU09ltJ1MkLMk1MHBwcHBwcEhG/joo49QXFyc+jeLgrY1LrzwwtT/P3jwYPTq1QsnnngiPvjgAxx44IGtss+cJaFKaQfvJa6a2rMvW/8iNb2tiqS4bZOoLA8I2hFCtFmSozY2nXsY5O477B4Q6y7D/nL3bs+mG4lqOG9Tmxm29RIDvS9EA3exsw4tyVBEYmoZhE29Kpsbg030MWzxis31Fn5narcpuWFD2BFEpaQA0G2b1LptNdoaZt95Nl/xWenNfCgtpFsdkQA1of95PxYXF6eR0ObQrVs3xGIxbNy4MW35xo0bm63VLC0tzWi8ilGjRgEA/vGPfzgSGja8bTtZu0e1S49q5cTImrouG8fmwjom8T7uAfvJswdZHVE2qikLVfhjk05V96EiGVzhTlPKTPgQdH5qGtuCHNB0d0LsBMXGyd2byPlMkHvPJi0c9MNG3adN/R6dr1ifqwp/wiYlNjZiCvkH/MemXlcq9iTlHWK9byS/yD9OrkfXhuk
em+zjxKIWVy1z8v5u1fm2A/slCVlQxxcUFGD48OGorKzE+PHjAQDJZBKVlZWYOnUqXWf06NGorKzE9OnTU8uWL1+O0aNHZzZXD95++20AQK9evay2syfkLAn1tu3kY5iQKHh0SyWSfF3iO0ofooToQiTJzOjd+3BQXyCqTYoK1XIkbNEQgaknvoFEHa3WJhrizBC4v7Qq1GnYLo2zEoxkIwJro761qZvzLrOpkQw7EqyOY1EqG+Icdh24mhFRtheyspoSTgYbd4DW7uve3H4Z+Y+KxNH3u2CRUPGcqB+22USWLJrKy8sxadIkjBgxAiNHjsT8+fOxffv2lFp+4sSJ6NOnD+bOnQsAmDZtGo499ljcfvvtOOWUU/D444/jzTffxMKFC1Pb3Lx5M9avX4+PP/4YALBmzRoAu6OopaWl+OCDD7Bo0SKcfPLJ2HffffGXv/wFl19+OY455hgMGTIk42NQkbMk1GvRxNLdearqnfp6EvskprxUuy2x/DZ9wGvzi5J+79LDzCayod5uYRNJFpEUPFGbHUdeQKaOkDomHFIfSMrLO+zoQdjWWDbG52q5hJretvHEDPNahOkTC+hK67BN6BnCjgarpQZKaYRIBtXnQugirGyU0agf8lYNL4Tfgfh792aIItF2oJbPkkXTWWedhc8++wzXXXcdqqurMWzYMCxbtiwlPlq/fj2iTd4lRx55JBYtWoRrr70W11xzDQYMGIAlS5akPEIB4JlnnkmRWACYMGECgC8FUgUFBXjhhRdShLesrAynn346rr322oznnwkiRu0R+RVBTU0NSkpK8M6G+ehc/CWZiJOogJp6V7sjsXGsDIAh9NIAxlaVB4jqQ2mTdlS3x8AIp1ojxqDWeoadZmVQotIMNnWY6osm7ChYNtJzNh23FCsa1Ys1bE/HsK8Fgw2pDbtWWolKM4glH7qVEyGrYUc9VRKqfrCpxDRMqKl9z3mqqdmBkrIfYOvWrVJtZZho5A3/fvZHKO6YmaCoZnst9jm1ok3mvTegHXgeODg4ODg4ODg45BpyNh2fHzXIj34ZBGZRRRZ9VEEV42wc+Q5QuxnxTkUhm8l7v4qDrpfJPCxSeHK/7gYx2hp2PRyDWqivlEaE3cOdihQsom8sxWoT9bSJQKoCDCUtalPvrEat1PSsKuxTa0LpuhYRtGz8ppRtqc8yVntPhtHoqGxGagEqMCPjbKKoQSPfNmU6SplFtpGlmtBcQs6SUKUmlD1poqw0k5jV6/Pwv2yiJG1Pa0LpBkXFq5oSC1qfqaaxLQiNt3UigGY6/FgIHGx6p4dd/+e9trSvfcjWUyopY3WINvZOcv/zEGtsAVBFnXL+5Naj4jHI5R1qWYVIdBnCJpzqb0pVxwdN76s2RhY15VZG9+rHCUXYvwsC5TeqfsCE3aa3tZClmtBcQs6S0DBBvTlZi07xhUkjnEzAxB406g2vWrt4ob581G4+oo0RheyFZ0Ek2cORHZtKmtTInUIu1Guh1oKpvqM2NX3qR5J6nlQ7HrX9aNDjbauXqI0FkBoZa22iksm4oHOh54ndTyS7wgin2GYyErMwJJcFXFptK70J1Oi6Gm1VtiXbTLXH3vEuEho22sFVbRskTASJJgrxGO3rziyQWGpGu8lUJTyLysqpdzUoGzTlohIrixcXVamzNDuDVe94cZzqOyoLjsR0uc+sXiTI6jmxEVfZpPzZ74feZ+K9x8DWZVCPwxulsyFW7DqGTfLCjraqJuRqFFWNeuYJvdPDJK/NIUIETCwSqmYmbASA7LyzHvN0e+L1CUqkXDrekdAWkLMk1AtG6FQyqIJbNLG+8/51I+RS0fkFtV7KZFxQiNtiPpxWL2U1MsagRgVsVN82LyUvwu7cFLb1klpzqMImomvz4lcM0m0iWTYfdjZpe1IeJHcRYuPC7tQU1AJI/QAmSngZNs8oBrWzEvuIY1F+NYOjfkwoCDOK3g4ioZFIRNddNFnHoXk4iu7g4ODg4ODg4JB
1tP2nRRtBMquPMvNy9mXnX6Sr2f3fAaoqn45jkQwVQdtbUm9Ov2hINYMP3ZtSrQdjsOnPziCXBgjHG/Y+GWyiImGLMhhYxCfsNpMMQVOFapkBgxxBs4iEqtdWdTQI28RfuWY27hpE7Mi6ocmRdZv7QhX72dSo2xjiB4WavWiP6fhIgHR82L64XzHkLAlVLJoYbNLxdHuESFJiGvJ+KZT0l9qhRgV7ObJ2ffL2xLkw8hK20tiGrDF492ElKhAJGKv9U0FrXcm1tWm9qa4rWxmFKDBSSYn6gWVTY0prhUMujWBQ7x/1HDAENasXfz9cCc+eAazsx+IVa/MBrD6PbBpqBL1mNg1A2hquJjR05CwJrU9GUJ9sIkwiNktJ+B8+zI6J1XywVpnMZokti1IXOgLZ1kN8mCn2NJTkkJZratTTJlJgI5JSYRO5YgjZF9UH9fiZf6EajbFpFanW7DLY+Gna3AMKiaf3iXhcNn3nZWU0gXp91JeoTcQnzAi5DaFVP9iyUVNvQ6ZtHCfUfUjXIsR7sT1EFJ1FU+jIWRLa4CGhhSTQmBfxv1hVO6ZsRFYNIb8R2qJSjBapLTQFWImLGNSXQ9iw8bSj59iCSHijSmErfm0iY+yc2ERjVKjXh5J/cX7KfRa2+E89d/K9GDIxC/veYwiaFpYdAwJe60ygklA2F/YRw1wE1BS92pzAxt9YAS0XIeO876NstPJtCS4SGjpyloTmedLxrLCTRUJZb3a1/pOty2pMWRSRdlYi5NcqCqS80ETjZvkrnj1oVf9GFpFUo2ryOQk7qkTWrScfDow0eN8rSXL92WlvINeHebHmiWbbdWTduOiHGHqNqcWL1cas3buuqhZnsIlQqQjqtgA0cxwhzy9MdwmbWl/xY5dZxsnK+rBJMntEqR82zMpJLWdh8CnayRj1o6Y9wkVCQ0fOktC6RBS1iS9/IVHi+1ZAOshE4P+BRshNRjswMU5Cu2z4F7EapBhNYal1P6KFh5cQMp9UYshskiQSSgmDaOti88VOU1Pi9mw6ELHIAyOEjPyxZf6dattnD0EiupMjbSpZVetOw7a2UVPKagkBQ5jijbDPkw3CtkJTI982PrsKxOcHI5K0JpQRztY23AfsMkKyZ6t43pXntFpnrnyc5rWHSGgkQCTUkdA9wcWJHRwcHBwcHBwcso6cjYR60/EF5GslL+qPlORH/VE/lmZnYONYjWmMtIRj41C7w79MtWdhKW/hS5lZL8mKeRKkk9e1EVHUMRsfMo6lxdm5Y5FAmvImIW22vbpWTkU1iNETmt4nx8rS8QVq1Ne/iN8DZC40oqtGkMgyNfrE7h9flMZCwMVgI9gL04Qf4CI2m9pem1pZZZzqBCCm2W2EWdTyKb9I2x6DfL+HXEMfNF1u087W+/tRS61aE64mNHTkLAn1CpMStPuQf1mCiBl4r3ftxqNWH2wc257amo29gdVaOu8ym3oelTSyVLma7pZbbzLSKK7LCBIdZzEXRv6827N5uMlzI8so0Rdf3my/tIRAJD42wg/2WwlqZ6Xazqj1n+r9zhA2MZVLA7JQ16dcb9khQ7XQYquSwANzCWFpe5t0fJjdjJqbC0PQ+1H9OHPq+JxFm5PQ++67D/PmzUN1dTWGDh2Ke+65ByNHjmx2/JYtW/DTn/4UTz31FDZv3oy+ffti/vz5OPnkk0OfG7NZUtXxaqcuJkwytMc8e4mQF79agxTUD1AtjLd56asI24qHQVWMq0RCjfBJ64rbV8Ge8UwgJR8riyKLTgCMhKsEziaqxuZH/RqF2rc2s1QKmXCqc7GBjb1R0CYODDYiuWyQJLVuO+yaYrWO00s62X2nfkzaRFFbCy4SGjralIQ+8cQTKC8vR0VFBUaNGoX58+dj3LhxWLNmDXr06OEbX1dXh29+85vo0aMHfve736FPnz748MMP0aVLl4z37VfH+8EinMzXk/mEUrJKSENUVNtTqOSSmg8HVAur1k4MNnY/qim5SvwYuVL
T5+zcNYiRQLZfBoWsqe9aGlUlx6pGM9VjtfnAYHNhKCSqfDW9r4J9JCgCDBvY+Deqpuk2hFN1jWDjVBISlEjZdOqSf1NihJMh7Hsl7CiqWqaR7c5Krb0/Ba5jUuho06t6xx13YMqUKZg8eTIAoKKiAkuXLsVDDz2EGTNm+MY/9NBD2Lx5M1asWIH8/N0PgX79+gXatzcdH2eCXxrhZITTIqVOwIiuDBs7Jtp+UyQDXqjpOpt0kLoPOc0ech2VTcRUIclqmmeXmO5nkLtIiQp8te5U3a9KnNn85DIIFr32/DvsyKAaBWIlOWpXprBdKILa+GSyX4Ug2URVbc4Tg1V5gwibDm5qNkmNrivnJegzP+xuXkHg0vGho80oel1dHVauXImxY8d+OZloFGPHjkVVVRVd55lnnsHo0aNxySWXoGfPnhg0aBBuvvlmJBJM8bIbtbW1qKmpSftzcHBwcHBwcHBoW7RZJHTTpk1IJBLo2bNn2vKePXti9erVdJ1//vOfePHFF3Huuefiueeewz/+8Q/8+Mc/Rn19PWbPnk3XmTt3Lq6//nrf8tpkBPlNIqH5CcbHiSExi4QS9SjtCU/GqXWniYj/q5P6hFL/Pv8iCvKlGYl51PHipuT0TR1R+LNidjUyRKNWqp+oTcrWIh2truv9oqYRP4sSAAZ1e0nxHKsRWPXYaD2teB3VueQLqVe15akaoaJ11mKa1Krdo0V0VI38qx1+1FIDbxchm+ieClEJz46f+Srr7gVkLmpdtE2mJ2grVKuOTJ59tgdDe1cTGjraQZGFjmQyiR49emDhwoWIxWIYPnw4NmzYgHnz5jVLQmfOnIny8vLUv2tqalBWVoYo0sPAeYSA5FOrJPJQIaRR7ZhElfXMmJ4Kk4h1jErWxHG0G5JvHqrgSK1htWh3yaDWF9LUrmgLpNai2qToveOSTFUuqs+ZX1YBIVsqgbUhuirULk9qyl9+AQv3RdAaa8CuXMTGMcCqOYNIJBlUMs2gpNrV+dLrT57vTBvArJdUA3sGG9sqlZgyMCKu2veFCeXjrL3UhGZaOuFqQveINruq3bp1QywWw8aNG9OWb9y4EaWlpXSdXr16IT8/H7HYlw+FQw45BNXV1airq0NBgf8hGI/HERdaCjYkCdkg7y1GOBkx5aIm9rRg9k5iDYlqlyTWiFHCKXRMopCtc5jdj9j1R+0RLivc1ahaFggn3Z7wQaCSLdUqSUXYNabyfkW1eVCiD2g1qzYvGpvIXdhtUOUuZBadkBjkulPmASoQWFrTyDI/muezTDhtzieDrI63qP9UPyYYvuo1oY6Eho42I6EFBQUYPnw4KisrMX78eAC7I52VlZWYOnUqXeeoo47CokWLkEwmEf3Py+Lvf/87evXqRQnonpDvUccX5RFRDrVKEsVF5IcWjfjnKCvh5RcrWdcmIuN9OFITerFAnY1TDYhVgqQSTnUfNp6gYSvLleiGjdcpjayqH0QigQ87sqy6HLAIsbpf1kzAS0xZowP6QWBhnxS2ybtN+0z1922Tjqc+q+R54RVnWVheyT3hw/4Yt1G4q360DCo/CnqfybZiZJ/tgXR64dLxoaNN49vl5eWYNGkSRowYgZEjR2L+/PnYvn17Si0/ceJE9OnTB3PnzgUAXHzxxbj33nsxbdo0XHrppXj//fdx880347LLLst431vrY2io2zMB7JRPegaLveNZ7Wg9S58T8BQ9KQ1QXyxMQUvIH9ueqff0gLd5mdl0i1Efqmr3IVpfKEY96fbECKSatQ764LKJyKrRTJv9UpNvNo5Zcll0VmLbY1A9S4NaPoVOEMX7RO31btOHXP0AtOlPzwix9zhU0kyOgabeWe/4sKNbYfu4qh9YckMFbbfSPtlvca9Jx0cCREKdOn5PaNOretZZZ+Gzzz7Dddddh+rqagwbNgzLli1LiZXWr1+fingCQFlZGf7whz/g8ssvx5AhQ9CnTx9MmzYNV199dVsdgoODg4ODg4ODQwB
EDDPD/AqjpqYGJSUl+NPae9Cp+Msevl0K/F+7cdIloiDq7/vLIpd5JPWeH/VHJOm6pGe93OqMjWNfmSStRdWd3i90NUqppu1pip6khW3qFVkErc5CcGOhhKc/t4Cp90jMwl/TRn1Oe8ezCKcYpaTrisehRkwZ2H5VKJFVdgxsn2HXUqpKc5tontqfXYWq8FYQtjcnU7izMqqgvqbNjbNxKlDvH7Wtptpy2Ttnmw5+nshnTc0OlJSeg61bt6K4uNg/vhXRyBu2rPk5ijv7OcAe1/1iJ7oc/JM2mffegHYQ324b1CajabZM2+u1VGw04v+B5sH/Q64HUU8Si6YY2AvYvyhCCHGM8TKLonKWNDBe6ypqkSLWuVFjbWLRxNLiheShrxqVs3GqQTrryKPWDRLCEWFkmlkAkX0Yjx+u998AEGHbYlAN8tnxq+SNkpIs1EepBDts43wvbAinjb2TDYFV52Ij/mKwsUvyQiSDqgiJNuxgH4A2bVVtrpmo8pehdlFSCKZM9Mky7/3eHtp4OmFS6MhZEhqLGMQiX76EuEWTFiRWI6GqlRNdxqT6Ddv8y0SrE9Ow079QafkpR0LZOFITG7RbUHPYQUitWq9IwcRKNkKnlslls/CuS14+hglkRMLA6pjNTq2OmUZlGRjJUz8mGML2QA06jjorsN+T+BFC6zXF1rXsY49BJU2yt6loIaUKGYNGL0XCJNnPQSerFGHXALMPAmappHbSYrARu3mvrU3NstxDNYtwwqTQkbMk1Nu2M2n8LxFZGE0eZsmIfxklkgS0vadYWC/77YUJG48/NUJF9yuKSFjqPeSUOiNhlFyGmbKkFjNkHhYm1YyYyrCxhmJPJjXCqYLORVzXNz9V6CWmJ22sfWxsm1RSa3Mfh20hpUTf1I9nm9ajNql3Fez6qOUXsgBQzHYx+JwKLJwf2qUwyUVCw0Y7uKptg6SJpBHPWmIx00ncFuv1TokpiaqpBJZ2R2JdakQLmAgx4jdJoSaUQbXcyEY6RU2xkugbTW+LJEyOZqpoJ/Ykag1rRPbrFF0EqLOA2OudjVMV80HTzCrhZi9SFr3ODzmdqlovqR+2Nk4XYddsBo2OqqTGJn2u+qnS+10knDLJETJdgJ6Op01KhMxZUJ9Ql47/SsKdHQcHBwcHBwcHh6wjZyOhXhTFtK8ztW0nqwnNi/pFHqwdJ03bq0bvVB0vKuGVFA4TJqmqd/a1v0OsE2VQfQnVtp2kvtCE3RNejdIx+DqI+KOURhVh7fJff1oEwYRJDKqLgOpyoAY9xLpb+r1t43zgjeiq22Kw6Xwm+zwGFJYAPB2vprJtTN1tjPOVeagdk8TucrJinh6r6HJAo+sW94B6PoNGtG3qUL3HlWPp+Pvuuw/z5s1DdXU1hg4dinvuuQcjR45sdvzixYsxa9YsrFu3DgMGDMCtt96Kk08+OfXfn3rqKVRUVGDlypXYvHkz3nrrLQwbNixtG7t27cIVV1yBxx9/HLW1tRg3bhwWLFiQss1sDbSDq9o+wGpC60mKvkAkqzz17n8wMBESWzcWdscTJq1XzKtZCYD6I5NrOMW6wZBT1kxZTtPR6lxEYZtcA+slq4z4qMr1oO0pm1lGyTorb2ACJtUMXhb/sJKMkA3sJbFSyIr8bAgcwkyBNwf1GRV0LiJZoaRR2T4AMMU8a8TAoDoB2LghqJZXNjWr6nVUtq+k2ttDOh4BSGiAhPMTTzyB8vJyVFRUYNSoUZg/fz7GjRuHNWvWoEePHr7xK1aswNlnn425c+fi1FNPxaJFizB+/HisWrUKgwYNAgBs374dY8aMwZlnnokpU6bQ/V5++eVYunQpFi9ejJKSEkydOhX//d//jT/96U8ZH4OKnPUJXf7+AnRs4vfVmXRHKibeoR3ytGhmjHQ4Koj5/cWYdyhDjH0v1DElOPmhEvWkrxMS0ExXFc+6akRWFSuxyB1bJvdXD05gqbLcxofSSpUvwKaTDyV0KvGzIFLkfNK6W3b
eaV2nKESjLTTFFqIKGPlXfVdVsMiYjfrYxlLIxtcz7FajXqjKfeYQEWOWbCyqaNF+VT2f6oeIjU+oDYKKxOT60vRt1dTsQEmfiW3rE/rhQhQXZ+gTWrMTXfpemNG8R40ahcMPPxz33nsvACCZTKKsrAyXXnopZsyY4Rt/1llnYfv27Xj22WdTy4444ggMGzYMFRUVaWPXrVuH/v37+yKhW7duRffu3bFo0SKcccYZAIDVq1fjkEMOQVVVFY444oiMjltFzkZCdzREEGn48oVQRCIlOxv8P5ZYxE/oOuT5JUyGFIEnkn6Sw8gqS++bCFFk+5ZAfuhFCJmmfngK1B72agSRik3I9pjdjYUISU+fqj3R2RUSI5BBI79sn3RbZF1K1Ni1tTh3BNRtQG35qYqVbPw/FTJpY/ekkg1aLqPtlgZj1FSsPGd1LhbCF4WEWniYGubvzCKmYUdzw3Y0sPCLpgj64aBeVxpBDWj31IqIRKJ6BL3JOpmgrq4OK1euxMyZM1PLotEoxo4di6qqKrpOVVUVysvL05aNGzcOS5Yskfe7cuVK1NfXY+zYsallAwcOxP777+9IaDaQR0ge8wmlaXERMaJIZ2A3bUQNWNMXFalpYoRTtfVQYBPxs7FoaiB1kuzc0XpFsXuTGh1UI6ZBe6Kr2SnZmoUdq8W1IKBRT+Z3ylL55KNQJt0MajcoSmo950/tIqWSZrX2j4ESJNHGh4ERJLHGkpM/C0u3oJC9ci06IamEi8GGrKqK+Wz4hLaFRWA2YVETWlNTk7Y4Ho8jHvcHgzZt2oREIuGrw+zZsydWr15Nd1FdXU3HV1dXy9Osrq5GQUEBunTpYrWdTJGzJNRr0cR+nqwmNEEK0plFExMwsUgoix6wOlHZ1iOhptjIsnryRe2FkDYBwL/iaVtQm5S6ui6xqGL1n2rUj5I1ixemahWkjFHrHBnxVeehtqNkQi+LFD0rl4jEWYraJuUdsMSBeoKKUC3OGGzIAb0WFi0/bVL0YhSRCYd8Ikv6XGTH71/E9EFgIiSWtmdQO1+psGkmwKAS0zAjv3szebUgoWVlZWmLZ8+ejTlz5oQ0sb0XOUtCHRwcHBwcHByygY8++iitJpRFQQGgW7duiMVi2LhxY9ryjRs3orS0lK5TWlqa0fjmtlFXV4ctW7akRUMz3U6mcCT0P2Dq+GjE//XH0vGGtYQj6zKwiCnbHoiyXo5SqVDbAvr2KaaxVaiWSjQtzsZp6V46ZzmlHrD1JpBBjanXqYCsR1tgipFBNaWsrqtGTNk4EqmmdaIM9HyScXKto1rvq2xLjdKRYyiwELQw2FgqMYRtF0VAy4i8+xUjgyz1Tq2XWBmVamNEn0diLa5aBpGNWklViKUgqPVWezB9t4iEFhcXS8KkgoICDB8+HJWVlRg/fjyA3cKkyspKTJ06la4zevRoVFZWYvr06ally5cvx+jRo+VpDh8+HPn5+aisrMTpp58OAFizZg3Wr1+f0XYyRc6SUG/bztqE/6VSFPMvqyMPeJvmJsy2KY/Vb6kPAbVXs4qg6ZWo6NXIerOrquJsED8GNW2vCqxUeIle6OprMaWsislYdyQG9SOBIKLWDzOo948iHCsQyYGsyLchORZpcRtVtSpqijH/S4v2jopFEy0Z8l9XSjgZbNLTagtmFWHb99l0OVKEScxrWikXaQ8WTVnqHV9eXo5JkyZhxIgRGDlyJObPn4/t27dj8uTJAICJEyeiT58+mDt3LgBg2rRpOPbYY3H77bfjlFNOweOPP44333wTCxcuTG1z8+bNWL9+PT7++GMAuwkmsDsCWlpaipKSElxwwQUoLy9H165dUVxcjEsvvRSjR49uNVESkMMkNChiRMDEakJVMAN7GglVuYX8ha6aQ3ujDOI+bSArvFU7IvElr25PFitp9amySEiBDSljp0kV6qhzUV18CEHgjgYW4iIVipiI3hMWkWCb6JZNrTD7Laum6Qw2AhmVTCrTYCb0DGr9KyP
S6vYYbOyY6PYsBFGhCsJEJTwl5sLHRbaRJbP6s846C5999hmuu+46VFdXY9iwYVi2bFlKfLR+/XpEm9wfRx55JBYtWoRrr70W11xzDQYMGIAlS5akPEIB4JlnnkmRWACYMGECgPTa1DvvvBPRaBSnn356mll9ayJnfUIXv/MLdGjiE7pvof/HUky8Q/eJ+08Xi1zSZaSLUkHU7znGVPTMi5T5f2LXNv8yVi7ARELsC5V5kXrBUvakI49MkNR1xVQ59f8M2/xeVmSL0TFFdKWKaNRuPjam8SppUpTmaEaYxKycmPm9WgagIqhFk7pP1RpLfZmxDkc2Nj42IiSV1DGopu4ByYmcZlfnwWDhWWr1cRJ25DvoObCx3vLMraZmB0pKz2lTn9CtGx9HcXGHDNfdgZKeE9pk3nsD2sGnRdsgadLfr0wJz5AgP1p2EpkSPp8oKlkUlUZCGdT0D30QMosm1lnJ86JixJeaL7O6ThvbJu2BrKZx5U5NDKrNDgOtnQxo/q7Ol8Gmmw8l3OILU3QWkL+MSe2o1bExKL6jKgm1+dCRo5SijY9KEOk48dURdgemgFG6TL0dg+0k5JpF2XtWjCzanGM1sup9X9io773LsmHA3xKy2LYzV+DOjoODg4ODg4ODQ9aRs5HQeo8wiaGW9Emvp8Ikf3SHpeMTrHe8IcXxrACU+teJBe42X6Pe7SUtuqywZSz1roJEqNgVpVE1tl9WN6fWF1pE/ZAnpuK825MjJazGVuw7bxP1pXNR58w8IkOOZqkCKxq9VGpCLUov2GlSDeLVZ4BNZKy1W282M476hHoU83QMVb1b1MnKCnc1ehdylFIVf6kIM73fVm1Gw0CWhEm5hJwloV40iOl4NW2viotodyRqA0Xq4WyK3tW0oDf9TvvLM7Vj2PZRWoqVpuOpbZFa4C8qkpliS1Zfk2VBxTAqkQxahwrwMgB1v0xFTo/Vv4gZ3UeKWP2jWO9byPqEizWw3n1YEXP2YAhO1mVbJJUg2nQ9YmTIghAzMumbHyU5ZD0bYm5DmmxaatqUUDCE/YERdB57CyKRAOl4C8eSHEDOktBYZPdfI+Ix/8snSpTwbBwDI5JUXETAOibR6CgDI4RhPljZy8fGOkOtkWSg7S7Vuk61RWfAKCWgq/zZ6VNJogKr1pY2EV4WMbYgvwyyp2zIwg9lHD0GMTqqRt/Yb48Jk8KGSl6Y2NEmOqpcM3Lu5P7vNtFhm4iX2vXI5lluI7oKul+b+Sp2T9mGqwkNHRmT0JdeegnHH388/W+/+MUvcNFFF1lPKhtImN1/jWARThYdZctMTHthJkm0LAaWOmIPC7JB9gBRBQhBfe7C/rJVhS82dkxMhKWuy+a3w8IQXoVyvGrJg41pPIMi1GlueyyKysogbEzY5XWD+5j6PmJYVFWOtrNrzeYrEk4bchU2bAin2ttdgGxCT1fOwvkM2yfUBqp6ncH7DrGxaGqPcCQ0dGR853/729/GZZddhptvvhn5/1Gmbtq0CZMnT8Yf//jHvYaEesHeAwnSRYlB7npE9+t/ODKyStPx0h4yAHtghNm1gka8gjuEUXcxtcNR2J6gau1kmBE51YeUnncLr1M2N5vaWbkOUwMryeBWTmItpo3vpgL1A8amb7j60akS57BT2XK9a8tdjhi5lJ+VYTf7sOk0RKOtbH4k2hy0rzug+8Iq7zjqUS28Z9RtZRuOhIaOjM/OSy+9hKeffhqHH3443n33XSxduhSDBg1CTU0N3n777VaYooODg4ODg4ODw1cNGX/OH3nkkXj77bfxox/9CIcddhiSySRuvPFGXHXVVdRgur2iIdmyGCk/6o+UJGgwT1PHM9D6T/XLidV/qj2Ig7ZmY/u0SYGrIBE+amjO1rWpE1VrIm0EQexbUOkBb9MdSY3u2bQyFQVHcn2qaP5On0Ot3ZVIFWsx5FmkhdXyGzUFSqsWxKha2F6f4vWR0upqVFFd12acirDN/1WoPeuV/dq0Mm2XveO
ReQpy76FFbYJAOaW///3vePPNN7Hffvvh448/xpo1a7Bjxw507Ngx7Pm1GqKRYGV8jJhGieCICYmiYg0aW1cm+FTNraaoA6Y7qNgmYNtJgBIwq8ZeKllTa/NUYRIdp9Y1CtfbJn3OYGPbpPZOZ7D4SDBityW5DIJ+mLJj8/xbVcfb1LqqqWLanEKszbNJxzNY1P9FWEtj8nHr+2i3SYHb1B2rolCV1IdteaXCpqzCe2zMokpN0SvrZRnGmIzfRTnWlDJjZExCb7nlFsyePRsXXngh5s2bh3/84x8477zzMGTIEPzmN7/B6NGjW2OerQ4W4WRIinWiDcb/QMqDP2rBfELZl5NVTahVZ6WA9VDyS9/CZkrtjKOuqyq3WcRQPTZ1fkoLTfmcWHh9qvMNW2ku3hdye0858qvW7ArRN5t7ll0fG3KgPgNoHaaFvZOFVRAnnILKXRXWqMRchU3EWJ0z2wfNTll4oNpcx6CBDMXXtB1EQg2SMBl2/st0fK4h41/cXXfdhSVLluCkk04CAAwaNAivv/46rrnmGhx33HGorbUwH88ivG07Y4R/1Cb8C+vJQJaOZ3ZM9OVIXkp8e2J7QuobaPGgUWwybPw/swG1JzojJczU3kpcJNr2BFXHy24DIilTj4uWHoQs6GHtPa0i5KJiXhJOWQiu1A8CtTlF2Ai7n7r47ImwZ15QImJzDCq5tDGDV+fCCCeDWm6lepbSUqUQ7729xKLJ/Od/ma7j0Dwyvov++te/olu3bmnL8vPzMW/ePJx66qmhTSzbYISziJbIiT6hTDFPCAjrMR8jHT8o1OiGTU1T0BSIWvsokiZG1k29WHeqpuNVwqlCJathqsNtPgjkyJ1Y/2iTjlfPSdjdSFRvT2+XK3a/qx8/qpG8TaTNpg5TTamHHEWktkrKtmzIpUrebFLDNnWdYdsbqb8pRlbV2tGvEIxJys43TddxaB4ZPzW8BLQpjj32WKvJZBN1Bshrcm/kkVpPlnrn6fjgxJR3R2JRAWYLJAqT2Bd6UJsQOS2jpnHFrkcqwk7vM3KhdmBSDfHZrzCogbtMaMUoLd0e25worlLFOhaiq6wIJL3nNCoSTmVbgG4VxcZZNafIgggnTMERX9G/zEawafNBwJANeyv1w4Z2u9OmItUo0xIvcQde6ymbSHNIcJHQ8PHV/mxxcHBwcHBwcHBol8jZtp3GUxPK7Zq0LxiWZmfWS9yEnn3Zkp3Q7DZLlYrptKA1TVlILVCxCUu9qzWxrK+7aiulRv1EGKKAi7D5MUgG9qLIRd0+NaFn4ghWM2bh6BB2ml2+V9i6wnHQCLeF5RWdR8hRSptaR/ZMsahPDdoJaffKAe+VbNTTMoQtGlLFSmGXabDfrXd7asq+HUQ5FexWx2eajneR0D0hZ0loUQzo0OToWU/4OGnHyWpCGeFkN2qMWI6oPqGyOp796FlHDZq2J8vqd6X/W7YiEjv3MDsmRjjVrj+0F71oqaSm1C0QYQo4Eabe0xlG3ZZMpC38VFXBkVoaYKGODx1KCYWN4EolQypZt0n3tlEvepMgSvg80gpVSbXT8xmuct+qNjPselJ2zVT/2DB1AGxd1VZMsQvLb/vaSpeODx85S0K9YL3j1X7yTM1uCFlV1fF6dFT7KjYNO8nKIoJGCyjJI+PYizVB9qmYtwPN1DCKfpo23qYMYk0gi44yeEknjarmkw+ieib0IjtQo6iUDIn2VrI1liZ+or8p1azeJkKsgJ4ni4gfgxpVU49BVUvbtBBlm1PFmIogSp2HxXxDN41XiZ9NW1GVcNqIG73zU6Klzc7D++/2QEKdRVPYyFkSWp/c/dcIVfXO3pcsmqmCipBo72NRmESWse0ZkHXZA84bCWUPPDX6yERItawwPuQfLZsfg2o9RNfViA8ljiSiqRJT33qEcKrjGIHVhV7ifFVDfJuSB0ZMWe94tdSAvUS8ZJIR2rBLCmxA/T8tonniBzBLszPVu1Wvc2EedFuMcGfDDJ7
BRsDDYBPNtLlvve8HG2cWbwaPZfSyDKeODx85S0K/aAAamryDOpCXCEvR15Ko2k5CVDrk+ZexmtCE8b+Uo8Z/WZIR/7rRgg6+ZfTho3rLMXgfjmz7BSzixUgzIT5x8iJgKXo2NxvSqIIRJAtSKz+TCfHxElOajreylLLYnhoxVdXxDDZtShlCJr8+0OirxfHbdDNSCaeVCb2oemfENOG3R6Pq+KA2UGGnolVFOgPNwpBntFoGoUZvbUoy1H0okVC125TXwD7W9nWjLh0fPtrRp7qDg4ODg4ODg0OuIGcjoUWx3X+NYD6hDIy1s4gpA+0nL6bymccoSPRABvsq9qbe2TgaaRVT9HVMrEQifkxNqCq31VaZDHLrTTLneov+9HI9pSDCsSgB4AIh1ieeRKhsxF/q9kSFLk29M6ierdS8O6/lMex6hZ3atTGNV3us26SKxeONqOn4MLtG2bgI0FaZYj0tAzt+m4ikem3puha1o4owif62hUh9O0hru3R8+MhZEpowLfeLZ8b0MVI7yrotdWC/KRKWZ6Km/Kg/DcOKm2k/bBVqC7e6HcH34UUBe5gTIsDM6mn9noVoRq711CyfInHt2IxKiAViyrbE6jpZap9CJX6ySExtRSgSXRGs2QETAMqgXb08L3RKpNkVEpXMqiVOvoWQiMGCDEl93TOZi0o4vedFVYarc1PV5+rz2IZIy81CRGsk1bCf3qNsx56FClEFmhFctT/y5tLx4cOR0P+AvS/qyW+glvhwxknkjpFGtRNSQ9If4YyxHyl9IIvdkdjLRiGcai2U+hJVYdPaUiVDaicktT87G6cKk9jNJ0RCqeDIwhZKrhOVPTfFiKEN5Hapgi1Mc+sqJFmNcKv1lTbn2EYgYuOnmQ2hj/d4VcLprTlsblzY3aEiISvwbcap87Mh7GGt1w7g1PHhI2dJqAL13c1S+SyjzL6I2DLmE0rT8UzRzx4WrMCd+PLxNIznFqFpKAsSqkYkbcap7TNVL1IRpjb4gzWST6y7GDH1QiV0zBqMWjmJ56TAJs0uRBqb254NMVNB72XPeZHJOtm+Gt2SvSktIo00jctKHsi1sIlwKj6RzcxFIsnquctCO1IZNh8ONgIrFdTvU4iEqlHadpmONxmbzzuz+j0jZ0loLJJOMlnqnZFL1cqJwcrKiaXj6U4Em6VM4H1gyMb0FrWZKrhflrauTWRVJMQ21kthEk6zi8w3n3z8sJIC1YGAnbtdpGaZenOSZWI6nnbXYuUcap1omFFZtbzBph5S3i8ZZxNVU2ETCaW2Sha2Ul7YqMrVyJ0NMZU/EkRSZ7MPNYqqwMbov43hIqHhw6njHRwcHBwcHBwcso6cjYTWJYGmXTlZgIoJjhhYgI9GaMROSCz1Toub1donBlXtGDQKoirNGWhbSLEO08pwXkv/6WpzDTTqyW5I7z7YeiSNz0DnSyPVIZtey16KYsRYdVJgsChdkNbNRm9ym5aaYpRS7mbUVqls737VeaiRagYqTAq5rrGtItXqcSjnT/69k2XtMDrqhEnhI2dJaNKkv+dYOp6Bt+1khFOr/+RzYw+BkLt7qB1EmK2Sbz2xzSZLsapE0sI+iRJY1XCekCG1sxBNvaslCYxMetclRJKWANgkPNQe82oLTFoGwc6ThUVTnPQct7GQUupYqaWWRarTpoe7zTNAhc0+bEz3lXS0TYtOWnNqQQZVImljtaXCxiGAQnj+2hDJoB8XrQhn0RQ+cpaE1nsioewdsrPBf9N3ytduqAS58djNyCyaGNi4WJS8bNV2eknSylNR0ReQF+Eu0pveotMOs54yNlELBtVSh9a7Buv1DjRjqySq471iJbm1J5uvjWLeRlUe1P8UsGuNKfeJZ1ZTAfcb9gehuq5NPaC6PRuCpIqLbPxOle2r46zqc8V6TZvzGbYISb0vlHOq1s4qBLkdkDmDzCObLg66Z+QsCe2cDxQ14SHMcL64wH/Tcy2Mf9180nIuxnrCk7Q98wmNMdu
MBiI4Yg89omg39YQ4snW9/XqZOl5Ns1MDe9Kis14bR4nADlGEpUbGLGyl1Ego9Q4l61LBjRdMgENuWiZWirBAG4n8yxHo0BXpTKVNWkCS+4ca2NuUBnhXZQ8G1W+RlsawyJ2FpZAYkaRen2rKVlW4M6jnQCXnXsj+nyGr6FVyaeOQwM6TOmcbURODd12bdHxL224LBIiEtgfy3J6RsyTU6xNaT162rCa0mPRJZ+smjP9FmDT+Bygjl7Kajqb2hK5HzS1jCErCZGWwkHYG9FS+qnpnh2WRAmaQ/T/FOk4fMbVQ38uRUDVyaVOHqabjSd91VvbCPuzkuQRW/IrRLTUtrJIXmyidVZmGSDhtyJpapuAlpkHXAyyN9C1M8q28Ptl9xoIFFoRTJbDKvRfUwD4bNdYtwNWEho+2v6ptBKVjEotwMmLakX2cqo1xmD2NCpsoiJoS8vm+kXmoRNWmhlONUrJ1mUWRTFZDfoCogiB2c3rJKqtXFcklPSpVgGPVBlUktaLHqNyis7X9RGWyZVE3yF7CKhkK28OSwYY0hdmOk22f1XWqJEq10LJJvatiIDVizBC2vZOyD5Y5Y2D79K6rbqsV4Syawkc7iG8D9913H/r164fCwkKMGjUKr7/+urTe448/jkgkgvHjx7fuBB0cHBwcHBwcsoRMedHixYsxcOBAFBYWYvDgwXjuuefS/rsxBtdddx169eqFoqIijB07Fu+//37amH79+iESiaT93XLLLaEfW1O0eST0iSeeQHl5OSoqKjBq1CjMnz8f48aNw5o1a9CjR49m11u3bh1+8pOf4Oijjw6039oEEGkSIGPBstqEn6PvE/cPZIr5BFXM+7+IGoz/6y5q/NEdJkyiXxCtbSxto3YNu3+3XJuopoDFKF3YH7YhGqSr6Xiq8FfBrgW9PhYWSGG32VSFSdQejOzXuz1Z8RywphGwKqthtZ6GiSJVAUrY9kE0siiWxwStubNJvYfd9Ykh7FIGGwGTWgPs3Yf6zKfdl/L2/O82QLY6JmXKi1asWIGzzz4bc+fOxamnnopFixZh/PjxWLVqFQYNGgQAuO2223D33XfjV7/6Ffr3749Zs2Zh3LhxePfdd1FY+KUg4IYbbsCUKVNS/+7cuXPG888EEdPGPaVGjRqFww8/HPfeey8AIJlMoqysDJdeeilmzJhB10kkEjjmmGPwgx/8AK+++iq2bNmCJUuWSPurqalBSUkJ5q5YiMJOHVLLuxNRRmfSVaZ7kf/H0jnf/zDfJ+4fx8RKBbEi/7iIX/VeEOvgHweSsmS1QEyYtHOrfxx70Hj7ybPuSyxVrlogESERFSbZCIlUsRIjIAxqDSOBbNvEiKm3npRZNKn1pQw2fd0LRVsksRMS3R653pE4IXUKaWx2v2Iq27ssX2xjqdaOqsRPTRUzERJ5HtmImihsxFTqufJuj+2TPRdVhK1ctyndsOmxHraKXmnhbCHs9M63pmYHSnpOwNatW1FcXBx8uwHQyBver74bnYv97+w94YuanRhQellG886UF5111lnYvn07nn322dSyI444AsOGDUNFRQWMMejduzeuuOIK/OQnPwEAbN26FT179sQjjzyCCRMmANgdCZ0+fTqmT5+e0THaoE3T8XV1dVi5ciXGjh2bWhaNRjF27FhUVVU1u94NN9yAHj164IILLgi872gk/Y+PMdJfjPzx7cV8f7FInu/PCsmk/49OJur/Szb4/4KiocH/V1fv/2PzZeuyceqf92JnEnlsNJRt8mfqE74/dj5Nwvj++LUQ59fYa9bbc7YJ2D7pPPJi/j967sh9Qu8d8b4j55OiIeH/Y3Nmc6H7Fe8Ltl+2D+88Eg3+P4ZI1P9nkv4/enHFceo5Ydtjf+r2bI5NhbI99TkWK/D/RfP8f+y4VLB12fzoOHbPkvnZ/LFzoCKvwP/nPYZYXvC/dojGSGimf5kgCC+qqqpKGw8A48aNS41fu3Ytqqur08aUlJRg1KhRvm3ecss
t2HffffGNb3wD8+bNQ4ManAmINr3SmzZtQiKRQM+ePdOW9+zZE6tXr6br/PGPf8Qvf/lLvP3229I+amtrUVv7ZQ/rmpoaOo690/NJ73gVLB3PwP0/iZUT7xQvToZZL4kF6d6HPC0gJz3CZWUW2Z5qfk9N6C1EUhYwteIPVREcAbwbkheqWb0afVVLHsIWuYgRWFmEpJjLZzIXGoX3/LuARG7VaCYD2yczq1fT8WrU0yZapkZ0aepV3F5Q2JQPsEcvNfwV141YeMAyyH3dWWSVRDPDFJip95NNhDeLsBEmeblHPB5HnDTWCMKLqqur6fjq6urUf29c1twYALjssstw2GGHoWvXrlixYgVmzpyJTz75BHfccYdyqIHQPj83msEXX3yB8847Dw888AC6desmrTN37lxcf/31vuXeIAx7TzMiyeyYGHhnJT/hZPYNrHaU2jzIKTHx5aDU/dSJqW2bGk725WVT1xkyqAJdtTwiXEC2bRKOTbZoYrAhpkn2MaF5llKEXT+s2oOptk3efai1imrNpUjA5JaaNqrvsGsnbci5QpDUGsmQ24xSuzAxGMGe7zTwQDcnfjjZqPdVVb4StGCwshrLHvaUvNnTOgBQVlaWtnz27NmYM2dOOBMLCeXl5an/f8iQISgoKMBFF12EuXPnUsIcBtqUhHbr1g2xWAwbN25MW75x40aUlpb6xn/wwQdYt24dTjvttNSy5H9+bHl5eVizZg0OPPDAtHVmzpyZdmJramp8N4ODg4ODg4ODQ2vho48+SqsJbY7UZcqLAKC0tHSP4xv/78aNG9GrV6+0McOGDWt2zqNGjUJDQwPWrVuHgw8+uPmDs0CbktCCggIMHz4clZWVKZulZDKJyspKTJ061Td+4MCB+Otf/5q27Nprr8UXX3yBu+66i5LL5kLeCmy6HTKP0SjrRkIgm9uyYbSYPcSCdGbwTNtdMtFQyFFKm4ipjWpVXFeOSqrRQc/2Ih2IeXu91gaWgkZPAkYGmxvHIDYiMFRcYxHhlCOmwvWhinwLJbwYuWMKd7nrkeo7auNHrPZiVxG0x3rYJSTqo0z8aUfYM4o1XbBpqGHTvSnsSHKQ7beDjkkJE5FL7ZquAwDFxcWSMClTXgQAo0ePRmVlZZqgaPny5Rg9ejQAoH///igtLUVlZWWKdNbU1OC1117DxRdf3Oxc3n77bUSj0T06FdmizdPx5eXlmDRpEkaMGIGRI0di/vz52L59OyZPngwAmDhxIvr06YO5c+eisLAwZTfQiC5dugCAb3lL8JrVq+RSTdvTLn6klkSt9aTjaIok5BSGN+XCUjBWVkns5c1alIp1ZOrxqz3hLcza5Y5JKry94wnhlPdpYz2lXm8bdbyNpRJNqYsNEArEMoigRMpG4R52q0gbtAO7nIxAazhtNmcRKGBQ0/vqPmwM8cO23/KC1X8qxNdGKBsSbNLxmSATXgQA06ZNw7HHHovbb78dp5xyCh5//HG8+eabWLhwIQAgEolg+vTp+NnPfoYBAwakLJp69+6dIrpVVVV47bXXcPzxx6Nz586oqqrC5Zdfju9///vYZ599Mj8IEW3+JDnrrLPw2Wef4brrrkN1dTWGDRuGZcuWpQpo169fj2g2un0QMJ/QJPmxsHc3qwllDxW1Fog+9Cy+DKlvIB0o7EPumMRqDsktWEuETgw2JJTtVyRSVIREhTSMDIXnCaqC2japLU8Z1FpclXCK0R3ajpN2VhIj3zbRW6+NWFFHbT2b2lEyLkIs3ihUgZBsDZUFQhyUiLOfGPsdqwEAct/xgIImTEoY/3lntaN0e3IvdrEWl41Tz7tSs6kKjpTWo+2gRjRpIkhmGAnNdDyQOS868sgjsWjRIlx77bW45pprMGDAACxZsiQtOHfVVVdh+/btuPDCC7FlyxaMGTMGy5YtS3mExuNxPP7445gzZw5qa2vRv39/XH755WnljK2BNvcJzTYy8QntkOc/NfsW+l/e3Qv9yu2iPP8PpiP
xEswjSsn8qH8yBVH/y4Z5jFIfT3GZqd/pH+f1CWXecMzXcwfZlpi2NyoJZYp5towRLgYxOqqmvEOPhIZIYKlZvUpCbQid2qKTfSQQRJg/J5uL2AaUz0U4NpZ6V9tnqkSNjKPCJHV7qg0OU+XblB/YiFWUyghmwi+CET+ZcLK5gGQryEGo26OEmEZMs/Cx20DeK957T82cMXhIaE3NDpT0+n6b+oSuWn8XOmXoE7qtZicO239am8x7b0CbR0LbCt50vFr/yThEPfnSiZNltOuRaB0ToV+i0qqWKlDvLUJIqFxLGXKNk7puVE3RE0JM6i9UckkzgCQqKafLvVDLB+i6gu0QwM3bab2vRc0lg3ptVfU+g1pPyhC0k5hNlEndHkPYnYtUAmvTO76V6xAZ8eMkT9wei9SzBBbZLyOrMfJ6pu8QUWsglySE2dGJjlF/n+2vY1K2IqG5hLav9HVwcHBwcHBwcMg5tP2nRRvBGwllgiPVE5R96ajrqgXu1GOUKPDpXsP0/mPrqelPtZWn2qGBRq0sIqvkOCKsYQFZl0Y4yU1llY73Qq1DpesGVIE3ty6tsWWRWm17sjE9Vb2Ltag28P4OqBBEFN0p20czqvewU+8MQaPDmUBMvduk2r1gaXY16pkw/mdZFOT5QQ6CRT3ZuqpJPo2Osu2pPqFq73gFah2nEglvFzWh2REm5RJyloTuofNhi+spYN2WIuzBQMBrhrRl6ovAJAkhlERI5GFkY8ekpkTDVv0zQqz2hLeydxJvIMHAntV1yvWqlL+LVluyJUzw80Tr3Fi60waqbRNd13Ns7AXJaiRtPggZ1P2q67L6T4aQ6zqpaJPMj6WeFVkDJZyqQTyBmgLnNabab5SRS7V2lB6G+iGiipUUUmiTjm+HsLFocuDIWRLqjYSybo9BS/V2r6spKtmDJhkROyvRh6gGqo43Qgs3tfaNkVwGlbzYkCG5dtS/iNY6iurrCHvxiTWmIIIbH8FkEVmqyBejnqqqnMHGE5SB1iFa1JhSwhmictum44uN0lwlFmrE1Kb+k8DQGsvg5EqpiWTqcxVqBztGQtl+VbKqCpjklpFGrXdlH3shWjTZ+NN678X2UBOKzCn03ku5s4O2v6rtBMweMEiktBHMrD5sHwJeCC8W+KsI2oaNRZTYu4Ga6zPCSY5VJab0Jc+it+xBKEZHRci93dk5ENajqnc6X3JOVAW5TbtLhrDH2SDMSKVFn/hIHmmwwbIQ6gtd/WBTW4iKEU7aOp1khNSoXwP5UPYSR0YaGbFi4xjYutRmKRuKdAJe0sUU/WKUnyHoR5b67mH3YntMxyOAMKmN7ou9BY6E/ge7yDuf2SvubPDfUJ3Iu5t5jHbK9z+4aJ940XKDPQhj7IXB1OG0tpN1QxLM6m2iaoysEosmZsVDOb1KGqmqWnxRUzItkuQQIdd/MqhOBYWEDKmNA+h+g1vRUKh1l2HXNXp/txb2RNRmyabXe9h1ohYqdfYss4n6sciib12L8gZK6BipZeVRbdTRh9Z/MlClvhjIaG2fUIb2SEJdTWjocOp4BwcHBwcHBweHrCNnI6HeAmMmJCq0yF5wYVK4X+j8K5ZE39QvSPbV6o002QiELIzkTS2pV2VQo3Rsv6oqX01vi1Ajmr70u6xwt/jWDFuYpbbKVDshqQjbDcAbbVSjR6rXp02/9rAh7kNNqasd4dS0ujcjpNZhqs9U1ZszQerg2XylaG4GsFLbM6hescp9wTIErOHJXgInTAofOUtCaxNIKwtkIiQrL3DaY96/EybT4AImRpBIqpRBfXm19gvN6qVvIfJQRVI2qm+W3mdqNwsETr+rim/V+F22aNI6EtGuR+wHxIzzbcilDby/FTF9HomSY5V7wou/WRuTfDGVL1vLEYKkkktKOAWiR+3sxH3agB4DI9whm/Dz7WkuAtxhhexE/VBSAh5UByD0k28H6XgTIB2fWz0pM0fOktD8KA/ANAXTi3Byqdp6+JepD0KrB6a
NSleppVPZOiMq1N5JJD50LiLhZJG2XWK7UEo4xf0y2JBfL1QySLfPWkBmwS5LbZ+p1vbSmlByP1IRTkBilkesjdTfnVqfnY2XsCpqEjsB2dgRMbAoondd2n0o4r+f1Ain/Iy2iKLKxFy05VNV9LJYSfUYDbNOVHGgyDJcx6TwkbMkVDGrV6F+GdGHdMhKzsDCCkAnK17YtIpUYRORY0SXkUaV6KqEk0YWtZuFKt+VF4t6jm3siWg7TqaOt6hnoWl7C5shNhdVCc6gCJHY9m1+nzZRT9bHPuyInGhBJ2+PzC8pRGBV03iVNLLjYqRWfZbL5FcsW2DgHwSiVyojTbQZg0BMVUeHvSRF7+UN6joOzSNnSajXrF5uNMN+n+K6LB2fJz5UopF2fKmoIl385akqehoxFV9wVhZNqrLe4hwQUKslL8Im/7R3vAWRtFHMBzWNB/QoKovIqAp0729ZPVabWk/qfGHTw5197BI1O6311AhSjDy3mKuHSlaV/bJ5MGsnWVVOIJNrMXjA5mxjAyU27Gumwx7boHjfekknuz/D9BzNMlwkNHy0Y2bTuvB+0diY0IcNlk4KvSaUIWg63qrWU1yXEQsmJFIFR2oNp9z5SbSpshH6eOfMtp8nXmsazRTJv0oQCay6HqnR25DnLEU01etqE82kaXHm9al2TCK7pd7DbJFGVhlsTNjZr9Eb9eNRT5LGp1ZRbU909gR2jmOkjIZGg0Wzegq1ladSK80+/pTtt4OaUGfRFD7a9y/OwcHBwcHBwcHhK4mcjYQmPZFQZlZfLHYYbCDRUSZWYstYBEBNudD6INraT0t/0Fae3nFqtNRCoUy/2GkXJTHNbiOuUVO7DKxcIExbJdVwnh4/u3fIPtQ2m2o6WhUN2aTjaZRXFCEFtEuihvM2rhRy9kKs1SPRXLVfu9xZSGyoETa8z1AWeOKtN8k5oc9jrb6SRR/ZM1oVJnGRlLYuS9jRWk+CiFF/e+TeSwSs7WwHUU4FriY0fOQsCfWC1YSqtk0q1HXVPvF8J4LVBcAVlQkiuPGmTlQPOYt2lzRFRFPFoieo3O5RrBNlULsj2dgHeY/DxkVAJX5W+1DLDAiBUx0ImG0TnbN/kdyLPaiVkdo3m4HWcJJjbSPvUJYWVv0quT+nRvSUdDkvXdLG8bmJgiPZCcC/zEaEFLbvKBcrWaTtfTvYOwgng6sJDR+OhP4H7GvFRoQUI73jmYF9VqD05QU4afC+NNmXrmyoKlqdxMiLS42ENlhYJdG6U7HGVO2dHqZDgA3hpNtnNaZa7Si7ZhTUUkkUerHrY1NPqULyPhQFGAwqWRV9GQ2t/dPmwsgli2Yy4qM+3lSSSNcVWn5auYsQqBZN6j5YDWfC+J9b7HwqFlVAM3Wx5DiYcIwX3qrPEKGOk93b6juqjeFqQsNHzpLQqEcdr4qb1Ruq3uLrR02J0SJ1Yh2iRnekVA972TKzcbX7EENQqyiAExpGJBkpUa2X1Gireg4oQWIRQ8+5pxE/C5GYShAJqBE2e3HZkEGWyreJBNqkvL3HoRJJG09QMg9GOBnkTAqBmj5WhT4qqWPG9HQuXnW8hS0Ugw2BVTtBhb29GIuaE7APDJmYMnjvURa0oCUvglipnajjM+2A5CKhe0bOktDaBqRZ0jAnGi6WZnVP/nGsTpQp62l0lApURX9J8lKKqFHPpPBCVx+g1K8zePqc/Yxp2l4lqzb2RjalAbJxvLAPtV6TQVXHi5AjoQxqnWzQ1oGAXeqdfbB5o1k26XMG0XqJEX3+cerfXn1Sa85AP4otUupW3qHCczDstqC6N6d/vzRyGTJZZ5FV9XjzRLIq2zaFaVbfDjsmuUho+Gj7TwsHBwcHBwcHB4ecQ85GQr1gLTyZWElVuqmBNhbJUNu1qcso1K/YoCbCsoJcNNGuF1Plqkdk2F2PbDoQqRHToJFPVR1P/VTJcdn0omdQ62ltBFFhq9K9aKPuSFQgRLv+iJkUMeqpRjhtBDIqvHN
mXZVkxxHx+JWygN0L1S5KWnSQdnki42jNLlkmp+PpZAQnCRrRJ8eaaPsopwInTAofOUtCvR2T2HuQEU61sxKDejPSByEROtG0FnvvyYX1QppMNS1WQQki6RSiipVU0khJjlCHCegm+TY94RUCZ0PKqDpeI5w09a7OxaaEgNUeq73e1VSeql73wqYhBDOXp+IazTVDbRXJoBJOG6GTmmaWtycQTLZPJgbKi/ivhU35gAo18MDT7GpHKz+ohkCFcg+wdwNLxyvvkHaQjncWTeEjZ0mo92ZiPMWOcPqXRQmRpHWNzOONLGS1QBRiTajqQRcqVGsjFVRwpNW+WZFG1d7JpuuPb7/i9m3qX5nhICWhNjW2ZHs2ZNXGGoltLkY6k/n8c0UfUpUg0+5IfH6+YTYEkfZT16KDKtT+7EG7F/Eorbau7Gsqt+P0I6YGBcQIrM26tD6VBEvkFqfKebFwfmhr7K4JzTQS2kqT+YogZ0moNxJquy0vbOwgGaxsR9R0ueKHaHZp+2SQ+8T7IxShe4cyAteglhCIVkY2Tx+bfupBt89ASC1VwtN1xYipTSmDjRJedIiQUu1qCQCDaiRvIfLhaWb/sgQjiCGTAR4dVVX+zNTdKxLTSC4PAKjG/MF/i/pz27+IB0u03w9V0ZNIqNpClT9DvA4eYiRU+X2KYr3WhBMmhY+2v6pthDqT/rxhIXOmjmdgqvc4tXfSHrRWdV4q+VW/RoNCjWSJpDFCSLPc95jBytRejY6qrgSi2b8yP9miSlyXQLZeUs8nrR8mcwlqGg/we5tZI7F9sJdmfmHL+7SZL4GNwls1Ulf3q9ZYsv0y0qT2omf3HkurhwkWkVWN+VW1vU3trFrywOtYgxNi+rvw3o9qBzJlmU3ZV0hw6fjw4dTxDg4ODg4ODg4OWUfORkK9YCl1m3Q97x3vH0drocQidQra7cLia1TZFovusDpMNeLHoEZHxXX1tD0BiyyyPvG0A5NqYC/4ZIbdolPthETXFe9P1RDfpuuRWItJU+9hqujVukES0beJ7tGsSdh1naLXJT02oo6WI3KmZS9SNdKoztewqKeFOl6t69RbmWptO2kZAIuYqqUG7LmvdEzai5FAgEhoq8zkqwNHQv8DlnrnrTxZ2ij4fulDT7VJoYIGthP1AS8QU1VEYUNKmH0Ss21SW3kygqhC7Upk0zGJgdVJBm35qVo05anpc9GiShV1WYiGZPGPhTG9VGrA9immYnmdo2anIxup05pyrexHVVqr6XMbCyluSOTZp2rHJIqh2DHYCKnU9Dlt5cn6xFsZ7IeMoM1N9hKYADWhNlVjuYCcJaFBCowBoDbh/8EX5bGC7za689huwyzoVmvmVKKiquPlTkiicl2FTfSNQVbbC/uw6VfPrJfU/tBqNyM1Yqr2kqYK9OAtaeUPKlb/GdQ/VyQHjHDqNkuagInWk1IfSm2/qtLaJgIbtGOSDViggJJB8dyptkhhH4e6PfocUJ+/XzHS6YWrCQ0fOUtC8yNAfpN3hJqOZ2021bQ9G6e++Gk6ycYMm+4k6ItVBHuQif6aslk9g9xSM+TECY3yhviQlhXuokqdLbOxSmIfInSZKBpSxXQht8YMui0qtiGkRCWX7BnAomVq1sRKrEQdEoK3y1ShuISwDw52jvWWmuy4WIRXbAgglh6o9k5W7UfF+4KWiLHjDXptleBG0ExQiHAkNHzkLAn11nawG0Vdpmae1ZvRxsyZ1kPZfJ0GTa/QaJk4jk0jTkgJSdvLKXoGOmeLFlk2dkRBI6bqPFSonZtUSynVcF62N/LfF5EY2QczhFdT+QLUNDsDJ5cWdZMEag2nmlK3sYtSzfTpeUn6f/Nev2R1Hiox5USaNQlgdZhaMwGZSNKPLi0qy6C+a2iwhFpcCdtTMx/tEI6Eho+cJaFeFIZowdgcWCQ09JpQ9uyxaU8Y1HLDQnBEiZSatlfJkFyLwY6NDWNtMEn0Vo1KMnGAko5noI0JLGp21XFsGU2
Bi48h1eg9ZO9Qbn6fTmppBpP8tlX7NQYbwZEquOHkUuzWpkI2etfS4IyYKpDFRSJUy7jQfVfFDwxVECUjaPe8hjpt+77nXdub+Tif0PDR9lfVwcHBwcHBwcEhhfvuuw/9+vVDYWEhRo0ahddff32P4xcvXoyBAweisLAQgwcPxnPPPZf2340xuO6669CrVy8UFRVh7NixeP/999PGbN68Geeeey6Ki4vRpUsXXHDBBdi2bVvox9YUORsJ9X7R1JOPUxa5ZIbzbF1mYF+b8C8ryiOREfJly2u/BMPs5mATLfKCpVeYWIl984Qs/KG2TaqRPC0XYHWSqvhHrKdUU1EB+8nL5vJq/aca9WRQTN4B/b5QVe9snJWKPv2fNp2L6DixX7udyCd4St0GNnWiytHalDPRfbK6SbG1shpt5RZSbHvBX9lq+YHcLlW5ZjaZCm/ElL5TsotspeOfeOIJlJeXo6KiAqNGjcL8+fMxbtw4rFmzBj169PCNX7FiBc4++2zMnTsXp556KhYtWoTx48dj1apVGDRoEADgtttuw913341f/epX6N+/P2bNmoVx48bh3XffRWHh7ufyueeei08++QTLly9HfX09Jk+ejAsvvBCLFi3K/CBERIxV25m9DzU1NSgpKcG0ygcQ79ghtbyYlIztS9oedSfv0C5x/4+jU77/x9elwP8QKCbL8qP+ncRjHaRxBdEi3zLq+7arxr+MpElM/U7PGNK2cxf5UmIPH5ZSZ3ZMTEhEvTnJy4wJmFRBlFqHaWO9xBDU3kis/+Ren2SfKrlky+Q6USY4IsvEbkaRfP/9Tl9yqsKdkVA2Fw9UNbvq/8mIT4PR0piqBZJah8jS3TZ1nWq3IVU45VWbq73pVT9Mb80pwM+dmu5m49RuS+r21HQ8616l2lRJH8/sPVBP3iEsRe9Zt6ZmB0r2Ox9bt25FcXFxy/sOEY284daqhSjs5H8X7wm7tu3A1aMvzGjeo0aNwuGHH457770XAJBMJlFWVoZLL70UM2bM8I0/66yzsH37djz77LOpZUcccQSGDRuGiooKGGPQu3dvXHHFFfjJT34CANi6dSt69uyJRx55BBMmTMB7772HQw89FG+88QZGjBgBAFi2bBlOPvlk/Otf/0Lv3r0zOm4VLhL6H8jlddQnlCkl/cuYsp7+4KnyUqv7oV+xaruzoAbhbD1VzW6hmGc1l1xEYQGVmFpZSInjhJs0woQ/dPuigb0NaKSxZUIHwK52VFbHi/WftMlEy3eVHN1jkTaLiKSNd6i8D4u6Tm7Crvpatkx+eVRRbUmriYHUnutqVJGfp9aHul9VYOV717B3j2qNlmh/YiWbmtCamvTATzweRzwe942vq6vDypUrMXPmzNSyaDSKsWPHoqqqiu6jqqoK5eXlacvGjRuHJUuWAADWrl2L6upqjB07NvXfS0pKMGrUKFRVVWHChAmoqqpCly5dUgQUAMaOHYtoNIrXXnsN3/3udzM6bhU5S0KDgpFLRkzZMpaOj8eIj5zajURVy4peilxg5O3fKz5ACvw/Lvq1a9NzXB3HFPMscseistkgl3Sc8Api0UybeTDy34FEEOn1DuilCcipcqp6Z1FKmzS7aLXkjWhSwZGFHRNbV+2sRK2HLKKjNlZO7LnF5kxJHVVft/x7ZEIqWZQjpo/ViDazVFLB9pEX8d/v7NqqVk5MWa82T5C+G1S/33bQF16BTTq+rKwsbfns2bMxZ84c3/hNmzYhkUigZ8+eact79uyJ1atX031UV1fT8dXV1an/3rhsT2O8qf68vDx07do1NaY1kLMkNBrRo59NoX4FMbKaRyKhbA7q17NNWze+wRBrqWwU8ww2Fh4qcaY3hBj1DLs1KINnXVnhTrfFIqGqSp3dtDYRXpE0Mtjcs2JkVSFhnDSGW9cpm8ZbEE4bMNKYBMlWiMRUtpDyjFNJOAMnW6xsQXumUrN68tpVywXUelKGsGtlKbznSs0atYN6TwU2JPSjjz5
KS8ezKGguwqnjHRwcHBwcHBxaEcXFxWl/zZHQbt26IRaLYePGjWnLN27ciNLSUrpOaWnpHsc3/t+Wxnz66adp/72hoQGbN29udr9hIGcjoV6obTZtakcZ5Kin3FnJQtkYNDKiftlSIY14Qm2iqLRdqEUXJbV2UlWbB27bKa5XYBHhVHu9s+stiouo4IiIQazS++J+5aikJ7KmRjiZgIkqqMWORMmA880EqnZV3oeQUm9ue4pxfti+jGr3IdpZSe5mFHIXKbq94DXf9H5k94U3u5BoWXC0ez327PXusO1jZgkTQYJkOVtaJxMUFBRg+PDhqKysxPjx4wHsFiZVVlZi6tSpdJ3Ro0ejsrIS06dPTy1bvnw5Ro8eDQDo378/SktLUVlZiWHDhgHYXaP62muv4eKLL05tY8uWLVi5ciWGDx8OAHjxxReRTCYxatSojI4hE+QsCY2h5Xab7J3MxEXqugw2Dx8rhPmDVus6GVmV2ziSB6gqYGIiKUZMWU1o2F2UGOjxsppI4SWinieZIIupd4uaS5lM25jLi7WeQa2RbKySVBW93EGHQLVjUkt8ZGN2uq5/nFzzblpWbqsCITo3Kgr1T4O3UdYU7qrLAcuys9cPpZu0xpbV+5J90FIDdk7Zs1FIq7Mx9N0giJyyjGyZ1ZeXl2PSpEkYMWIERo4cifnz52P79u2YPHkyAGDixIno06cP5s6dCwCYNm0ajj32WNx+++045ZRT8Pjjj+PNN9/EwoULAez+jUyfPh0/+9nPMGDAgJRFU+/evVNE95BDDsG3v/1tTJkyBRUVFaivr8fUqVMxYcKEVlPGAzlMQusNEG1y37OvFVbgzvw/uTrev082TlXH2zz0Ixl+iYUCSnIt/DDrxMiqDdTtMQJnI2AKWu9Kuy+xN5dF73i1X7ta10mWRWJibZTaT16Oemq1k7zLUcsvRB4F01p0Mqi1nuq6NoTTqoUobX0cfHvKs5Fa5Yp1nTZ19jZWTmFDbtFJ3QXYb0r4QGXPCuqQwSKmnm21i0hodnxCzzrrLHz22We47rrrUF1djWHDhmHZsmUpYdH69esRbfK8PvLII7Fo0SJce+21uOaaazBgwAAsWbIk5REKAFdddRW2b9+OCy+8EFu2bMGYMWOwbNmylEcoADz22GOYOnUqTjzxRESjUZx++um4++67Mz+ADJCzPqGXvpDuE9qViHv3KQjuE1pc4P/xFef7XwQlZB/5Uf9LmXmCFsY6+ZYxTztWCI9a4u1Zt8O3yOcTytIrTPXOlrGULfUJJZE7dRmB7B3aRpAinICfJFJvTgtjekZqRSJJfTgZMSVEkpJQNdoq7jcplsc0JP33LXt51yeJ16GwLR59C94Wk89X+4jNBuHkIiRGzIJvzzs/uSSJIMZcGcRAgQ25pN6hop8o9//0/77ZfcyOg6/Lek77F/neD+yZr9o2ed4hNTU7UFL2gzb1Cb365QcQz9AntHbbDtx63JQ2mffegJyNhIaJBhIdZcsY1JcSr4UKrpSkUF78jISqoCEQFhkUSaOaemZoAzU7kIGiXVWqC/u0gtppSK31ZC/qkNP7jHDaED3arczz0lRT+zZ94lXwHvPBMyncWk7LEjGmwjvR+ZcxYsrPqXe9lscAQB7JEOi2UKJyn9yzMmmkKf+28R3lPqHCFtnziJRUULLqJaYhuzkEQTKZefyiHcU72iXaPr7t4ODg4ODg4OCQc3CR0P9AVcczMP9PtozBJnUkb0/9glS+stW6HLZPFrncIXZWsojwydFH0m1IrlaxqQlVj817/my6Hqn1nyqs/DotBEyyh6MmOGK+jqxtpTfSRKOlYr2mWhPK9qH2A7dJgTMxJouO6ho+bcf1rPSYrKo8u23Kx22io1FRMU9bmYqepbzemd3vmshQNqunDiie3y1r0Rn0fZRDNaG5BEdCswhm28ReDjYXhRbR23QLUaAqlKmgRTRNZ7WjIcOwzkqikEiu62xtsHKEQlJzyciqqmZXXwY
225OJKVO9M7W5/4Vpk6JXyl6sjOnleTDFPHvOaKlyLtBk89PG8TbHZCoiNFKrNQVRyaXcHplAtnISygwAIC9KaqpZfarYi17tO0/T8ez36C3XUt8pimVgO0jHJxCAhLbKTL46yFkSqnRMUr+eWY2TGlm1aevGYOUTSnsBh/jDD7s4xsZ3lEbVxLqnsAmnOmclyqkq5lXix8aJrTIjeaLgSF1GQDsVWXQHYuM4gfV6U2rz4PWlwUVD7COWw38P8BpOPxjh1F/E2n5VX2WlnpQdViFZkZ27GFWBa1ZONlkttYZTbXlK27mqr3v6PmOtNgN6Tcu/xZa9eLONbFk05RJyloR6bybVoomBPZDVBzyDjaBBhkpCvLBJ7dN0r5qOJ+dTFSGJIp9I2H3i1XXVFqI+s3oyhirmwyV+FOo+bOZC22xq/dnVtpoqvKQh9N8n3ScjyMEjl4z41Sb85131Rmb7YB/jbL9sH2x7QaGTdT94xDT4XFRxEW1iIDoVqEp4dZkM7zPU4jnjzeBlpe1oC3Dp+PDRLkjofffdh3nz5qG6uhpDhw7FPffcg5EjR9KxDzzwAH7961/jnXfeAQAMHz4cN998c7Pjm4M3Elog9nVnYF/nvD5KS4mpNhwMVr3j20LGR1XqIjFlhIsRU9b1qK0IJx0n2ip5o5zqMdiQQTEtLpd8qH6iTPUuptlVKBHO5vbr65gUsjqeER/uUexfVyVv6jj2ElXT8ax2VJ2LSlYVMLLOCLJqdM+uI8tqqdebRRpZNDMv4s9C0I8p+X2hOUnI8LmpBP99tsdIqCOh4aPNPy2eeOIJlJeXY/bs2Vi1ahWGDh2KcePG+XqYNuLll1/G2WefjZdeeglVVVUoKyvDt771LWzYsCHLM3dwcHBwcHBwcAiKNo+E3nHHHZgyZUqqHVVFRQWWLl2Khx56CDNmzPCNf+yxx9L+/eCDD+LJJ59EZWUlJk6cGHge/Gvfv4yLi7R1VfDUYcifU2Lv+EgsPdpoVJ9QFkFLiD2DWV1jMuR6VZt9qOlzG9AIsee+KCS1mVzNoO2TGtOz7iYENl6fZB9qhIbXcGpKddqLndVnBozAsN8si4zptZnaftXIpbqPevK7CBqRbA7sHLTUVrm5cez41XQ/G8fmQa8teW5zZw6WZhdLHsg+1BadhgmTwvaVVjomib+niKfxSiTa+uLUlpAMEAl1NaF7RpuS0Lq6OqxcuRIzZ85MLYtGoxg7diyqqqqkbezYsQP19fXo2rUr/e+1tbWora1N/bumpgaAVhPKUuU8DRVu/aeuvLQgYTQtSh5c3o5Jquqdkbds1PTYiJUY1JQ/nUvAWk+gGZLsOafs6UZspijUDkestizGui1pYiU6jqYT1WYPNmTVnypkBFZJ+bP1GNQXEieDwVPbdHssBS7tIYO5iNvLJx/3bH5sg96rw4z0GdTyKAbV9k1+llu096Trks55oYM+95Xe8aIwyWONZnKod3wuoU1J6KZNm5BIJFL9UBvRs2dPrF69WtrG1Vdfjd69e2Ps2LH0v8+dOxfXX3+9b7m3JpQX0EtToGBf1KpdCbdZykL/d6UmkD0/bGp1VFLG6jrlmksWGhLtorJBONV+8l6BlWqzxIgfPX6yOUY41Y8JuduSfxHrzc4V6Jr/p03tKIuYerdno1y3IZysMxv/oPaDXUX1l6x+eKuqd0o4A4JHqrSaU3buWAcmuXc8PaOafVLYHtIqrIRJvo1pQQuTqPWP8310u5rQryLaPB1vg1tuuQWPP/44Xn75ZRQWksgOgJkzZ6K8vDz175qaGpSVlUnbV79gbOxFqHuO6NMmQ/1ipX1+hR8+JSUWDwyV+DHYPKj0Cx58e2pqXImOWrW7DFmlTpXrGnibSU3UIyvhqbI8ONFVoCrXVcLJlOs2r2VGYFmTDTYubGN6/lNR/U7TwSsFggcZ2PVhQlY1TkDbe7LzJHp4MgLLfj8xEh2VyS87p6w0SxEikWdUhBHTJCGmbYxkMoJ
khs43mY7PNbQpCe3WrRtisRg2btyYtnzjxo0oLS3d47o///nPccstt+CFF17AkCFDmh0Xj8cRj/v9CgsiQEELfCCfBreYYtM/To5GiCkX1VSYP0AIqVOjWV6ywuo61RQ9A5X3inWYDKranm5P3K8aRQ27dlS5Zup5V1XqdF2R/FrUerIXtdoxiMHGO1Qhv6oKnBEaRvJUqKpyNXK5s4EZmrORmnpfrR1V3UTosfnWY3sIbq5Pt0YYJ+00hOBpcbXEgxHYWNT/zuPNE8hxUDZN3jXKh2cDIaqqv6hvWdtHQpOJCJLib73pOg7No01JaEFBAYYPH47KykqMHz8eAJBMJlFZWYmpU6c2u95tt92Gm266CX/4wx8wYsSIQPv2htXVIngG9jCzaQPKoEZjeCrfwsBe8X1Tje8Z1NaT/I3BNqhtj86FjGNRWWoILx4vOw5Wd8rgPadqVyExLU5tluj2xPQ+6zoo2hbZdBaysUbiUdmWCZdK8tT0uZ4WZ8tUg3j/umqEUxX1qNZI6s9baZbJU+qaoJQ53bOILO0gJBNTTZjE7JjYPnT7Pu03xcZFg5aDtQNbpTDhIqHho83T8eXl5Zg0aRJGjBiBkSNHYv78+di+fXtKLT9x4kT06dMHc+fOBQDceuutuO6667Bo0SL069cP1dXVAIBOnTqhU6dObXYcDg4ODg4ODg4OOtqchJ511ln47LPPcN1116G6uhrDhg3DsmXLUmKl9evXI9ok4nP//fejrq4OZ5xxRtp2Zs+ejTlz5mRz6hmDfXmrX9RWCLMTDrVZYj2ESW1Q2F/F2ZAdqlFPG2GSul/ftoJHPa22x0DU9klqBq6lxVXBEYNaO6qmO5nAyBsJVG2R1KinjeMGjwSyfWjbY3Nh0UG9S1xwk3xV+R4mVPEoDKupJovEbkastjkPWokLE/blkRS9OheKEN1OvMb07RUuEho+2pyEAsDUqVObTb+//PLLaf9et25dKPvMj+3+awTv/07Ws6hxYttjhLPNlJKMcDDVohdM5ERrKTOf0h5h82a1AVWuW/STt+ntrqxH92kjTFJb8QUXEiWSpP6TdkzSLJVs0vFKfbdqi6QSTu6bqaW7bbooMahpa5byttHwMahiKi/Y+VQJrfrM5x9O5JqRVsUs9a7uQ1Xb0+3ZlFIxKMIkNRjhHdcOUvuOhIaPdkFC2wK7EkDau4o8A3aR9xZTqBbl+X8c7KHHlqmG1mGqdgFwwqn4sKlqabotcb6s17usehesjZpd1cKYXp0zI6tU1COo3Jn1EvuAyfNHQGhdpyo4ItdbFRdR+yRCONXt6cIk/34biO0XI3WM+ChCJGY7pBJO3gLTf33Yh61qCWNjs8T2oZraq+If9jOjHwSeS8trXTWxFiP6NgSWBhnEgIIakVTbdqp+onS/qjreC1WEpIxrByTUJDMXJhlHQveInCWhijq+kPAF9lBl6njuOyo+uKinHeknz76A1YgpUy2yuXj2K8cZ1U4ZqieoTBDF7TFY7YOsqxrdq7ZNSjcStcORsv3mlhFiytJpqhKeRW0SSf/9aSVWEi3OGPljhEuJjqoiJG6BpM2XkWElMgjoYiC2D1XIqQqi+PaCRVYVBT2gX2u2T05Mg0cQbXq4hz2OO6wEJIA2GTHvcyEbzU5agIuEho+cJaGK6aweUWDrag9uar3EfN9Yqoc+uIg6nr0wWfSyfpd/e94XukouqVqa7HPXTv8yBrWWkkEllyyaqeYT1RQ1+8Wp63rPH/XrtPEhFbseEcgenmKXIpsUPU+pM59QLVKpWCix9yqPcGr3E4t6qhHJhEgabZTwYbt/2BjMK2DXMN+ivpS3AdU+diOUXGvlPIwgqgEKBjWKKhNA7/ZYsEN+hzTs+d9tAEdCw0fbf1o4ODg4ODg4ODjkHHI2EhqLpEcHWEqdwab4ntclaXVzVlDFJWRZ4HQ8bRVJvmRZr/N6TbWst7u0EQ1ZKOEZ1JpaRTDAvD5Zz2i1w5GYjlfrmOX
ovbg9hgS5p9h+1ajkDmLWztDgWZelrNWoJ1fWs2XaPthzRleuBwdPvfvnEhPdP9iUeSQ0fSCrRlC7L8XF2tF68vNk0VEWuYxFtNcu7RpGIv9UWR8hynrqO6pl3dQ2pT5Q5xStFKw9tu1MJjOfRjuYdrtGzpLQeB5Q2MLRs5pRteg9HtPuPNWOiaVX1C5KFGKLTsk6g5EXNXXSVsXmsoreok6UHRsj57RWViSJypiQFbCqtRHv4c7U7P5ltF+72I5TJXCU/PkXycp3L1Tip9Ymqp2VmEAo7PS52iFK/TazqcX0lhrwY9UIslpjy8gqg9xkRBUNWdj3URN6UX9AmVTAZwh9p7QD0ZEC1zEpfOQsCa1PAtEm971a/8leBGyZGrVQQfthU3JBVlZJiFJPyMiRKHKSt0eOVY5mqs8ydXvqOFXUw8DOgVCfSVXvtA6VbEtss5kAizT6fywNJvg9oHZMYuNUZTn7fdNe7OQ3qhAuti0Gm97xNlBbaqrthnlNpH8cfV6S7TG1OduekrFS1ezsUZEvk2aN1LKgABXnyUp4clLEOVM/Xkam1e8VhTiSYATrzEbfZe1QHe9qQsNHzpLQ/GjL6ng16qmqUdk4luZgKRKqhGd2POxLmT24VDGRLyViURxuY2CvRi7VFD1N21v0Trfx0xSjqJGY0N5TsXZqDuzWodFxlq4jZIMQU+bhyaCm4xkYIayVPxTZXAQ7JvE2tukTb0NW1XIE9ozidlEaqbWBGgn1PpNVswl+/TX3EwYapRVb19L3APlNqZFQRn7V94oM1X7JOyRBMintgGAqcCQ0fOQsCfWq4xkh1ZWi/nG8/tM/jj+QyMORRQcJaAcmac1m4CVIJNUnp4BZRI5FUaNkXUbUmN2R2qVIreukPjb+RRRq/afa0UioCaWg/p/+ayEbulPTeBIxJfeiWuemKtx5NNO3SAbbnlrrqGzLpksRg5rBkf1Eqa2Utj21TpJFEWvJraebyaf/WyX6eiRYe5arpQcq2G9AdVNhYO8V+v5h6Xj6nCb79dwEhjha6M0u2h8SiQiiGabXmVuFw5fYO668g4ODg4ODg4PDVwo5Gwn1quPVwngbqOpJXjPEVMpiTSiDmP6Q1PE2hsSygb2oeqciKSaXFYvvae/4gB2OmkPQcRZen2rNJYtw8g5HqghJa6nJop4sIqkawu8UVe9Ba8PZPJhARu2YpEIuKRCD/ExazqOymkuIas4fs/Ds9F4ztV5V9Rzl9bTa9myU5mG3b1Yjq3wyLOrp95X2r+YvITJ12/0DqVDW6xPa9v3lkyZAOt5CC5ILyFkSqoByF4u6J/VhxiAXrovpfdUWyLCUtxeqcpLVk6qpfEby2Dj2MFO/Jtg+GGzSSSpZDdonXrVZsmh+ILfoFLenCo5Usua1TwL4LaASU6XeUzWXV8fxLkr+/arpfbVmNSraEak2Syxtz56haqcmuq5nzirhVO2tGIKa5gO64TyD3DteVL3LpDZozaa6HnvetT3n9MEEqAl1bTv3jJwlodFI+oOPkUHWtpNvi1lf+EEjobS3sFh8rhaaMzEII4SMwPm2JUYuE+LDR9knoBeOyAImsky1SlIhEk4qOFKirWLveHZ9jBiRtFnG+8Sz/foWNaNS9x8HI5xqxFRVqgetMVVrQtmx7hLLnRnU+k/1uNi67HlJazEpafQPU9X7QaGKkNijV71mqg80FQ2RHedZkEb1QzEWIc8e6nAVTAsg2zFJoti2Fy85YVL4yFkS6oX64OY94VthQh6ofX+tENSHUlW9i33IKdQvamZ+r87FRuEetgiJQemlzFLvFubyNib0VKxE1mUG8bonpmapZJMSUyKwbJ+qoMdGXKRGOFWoc+HRUXVdjXCycxon39j1Hpao9qanbVvJfcxiEarwlEGPeobtNsBETYx1q44lwnO/HRDHMOF8QsOHI6H/Ae3BbOEdyp5cqgefCloTynikjYelsp6aKmdg6zIlpo16MmwPTzrOgtSqNab
eyCfZPvWTpQSR1Guy3uwWHp6qkTz7/dSLtYTq9nSfUP84JbJaJ97uKmlk860j73xGrthcmPsHG8eyPzZuA6wYldtAqZFF/y68xJRtK584bqiRVvW5zcgvr/lnN4FW889+33IHJrFbWUR9dgd9rqpBC9c7PieQsyQ0FjFpUU32MOfWS/5l7KFq06GEptktrDlkBK37sYmEMtjUV6pf3qr4Sd2v+ECm9icMShRVNPpXU+UqWeVCIm0cI3Q0IiWm2RmRVAkng+qn6d2HStQY8VOJJNvHLjHaysDOEysDUJ+DlGCzxAn1GPWP4/Xy/kkrxD5KIn5qxLTAQqzFoNaEqsvU2usoa9vJ6g/Ys4wFBoIazLPnpyByag/Y3bYzUxLaSpP5isBZNDk4ODg4ODg47IXYvHkzzj33XBQXF6NLly644IILsG3btj2us2vXLlxyySXYd9990alTJ5x++unYuHFj2pj169fjlFNOQYcOHdCjRw9ceeWVaGj4MuD08ssvIxKJ+P6qq6szmn/ORkK9UNNkvEg93LnQKJUYpaTpFQsFesSTAjERUZjEbi2bNI8KNb0fZVZG4lzCTu+rdbGefagm73YpdRZl0dZlUUWmSGfiIlW5zrv5BBcEqSIh5Te/U1T3ysp1MWIadu2oTakSm3NLner2PBdWw9jyPlmUkpVtxPO1E8Xu7aK84G4QPCIpTYUr4Uk9mNq20yp5rDzj1aiqEyZJOPfcc/HJJ59g+fLlqK+vx+TJk3HhhRdi0aJFza5z+eWXY+nSpVi8eDFKSkowdepU/Pd//zf+9Kc/AQASiQROOeUUlJaWYsWKFfjkk08wceJE5Ofn4+abb07b1po1a1BcXJz6d48ePTKaf86SUK9PKKuFUu1F1Noi1SfUpgOGXaE5I7+eB6aqbFRSNZnMTYVaBqCOU6ySMllXFTAJ9wBXyoYrOFJJrapIZ7WeqqUOI6ZKX/fd+9AETGoa3EvCVDLICK0KVeGukkY7Mhh8e2otKrunFHsn7hOq1YkysNIqG+W+2npTFSaxcar4ic6FtdUUuiMB5H2h1nFSdXzAmpdWRHsTJr333ntYtmwZ3njjDYwYMQIAcM899+Dkk0/Gz3/+c/Tu3du3ztatW/HLX/4SixYtwgknnAAAePjhh3HIIYfg//v//j8cccQReP755/Huu+/ihRdeQM+ePTFs2DDceOONuPrqqzFnzhwUFHwZxOnRowe6dOkS+BhyloR623aqUOutVMU8JZKixxsFO6bW7surWjSJ3qT0gaQSOrUGySbCKZLGSCxO1hX3QeZsPC8MFgml/dpFwZENMaWEU6yvZOp4Hs3TflSMXLL5qYSTjfO+f1VhEoPNu5U22QiZrLJljDSqRFzujivbJaX/O8hz/cttkfpfOlITTdmYy9uA/UbzWPZHtWOSxUTCD0ENULRD2JjV19TUpC2Px+OIx8k7IgNUVVWhS5cuKQIKAGPHjkU0GsVrr72G7373u751Vq5cifr6eowdOza1bODAgdh///1RVVWFI444AlVVVRg8eDB69uyZGjNu3DhcfPHF+Nvf/oZvfOMbqeXDhg1DbW0tBg0ahDlz5uCoo47K6BhyloQWRNO/3NWoJ4MaHW0zcVFr/8DVaKYN4WTjYiGq/puDqI6nXp90YHDFvJcQMnLJXno0bW+RymfdjFgtg03pCrdZYuPYfjXCqaao2bigpFONhLKoog254n3i/YuCEr/m1lWjnuy8dGD6FcHeKeyOTAw26XhWMsWz8aoq38LySX1fBM0cJS2CB95nfjsgqjbp+LKysrTls2fPxpw5c6zmU11d7Ut/5+XloWvXrs3WZlZXV6OgoMAXvezZs2dqnerq6jQC2vjfG/8bAPTq1QsVFRUYMWIEamtr8eCDD+K4447Da6+9hsMOO0w+hpwloXnRlrsVqSkXRT2bCfRUCvsxB9+vVpfDOhIxNaXYHcnG3ils2yZ2HKqBvXpsVl2OWk5tMZU6g2r/wkgoI36
MNO6k47R11cgqW6bWcLLfqEo4g0bfWE0jS/fuJOeEjWPbSzCSI0I1plfJJQM7x2wfbJxCTOOEcFLSSOamtjxVW4PycqvWt+qLsqgng/oMVYmpN2LK3hftwGopKEyAdLz5z/iPPvoorXZyT1HQGTNm4NZbb93jdt97772M5hE2Dj74YBx88MGpfx955JH44IMPcOedd+LRRx+Vt5OzJNTBwcHBwcHBIRsoLi5OI6F7whVXXIHzzz9/j2MOOOAAlJaW4tNPP01b3tDQgM2bN6O0tJSuV1pairq6OmzZsiUtGrpx48bUOqWlpXj99dfT1mtUzze3XQAYOXIk/vjHP+5x3l7kLAlNCjWhag2a2peYgUakmHqStl5kKZJWvqRqXScr+qcRU9GYXlWuy6l8tYc7cQxgqXe1dpTtN+aPWigKdJuUuo3giEYQaUROW5eby2ueoCzqySJoNlFPZXtsvmpXIXaObcZx43f/ODltIkYu2bVQBZ8q+DPbUytNIsGsZErt3ETlf2I5AkvHq2JUmxadbBm/BVq5fEtty8zQTtXxkSyo47t3747u3bu3OG706NHYsmULVq5cieHDhwMAXnzxRSSTSYwaNYquM3z4cOTn56OyshKnn346gN0K9/Xr12P06NGp7d5000349NNPU+n+5cuXo7i4GIceemiz83n77bfRq1evjI41Z0loNER1PEM8xh4+/tOtGhLLKnr1AaIWmntJHRMcqftkDxEb5bq6DwbZUik8NXtz41SzaS9JpL3ZRdN4u9Q7sVkS1fE2aXZG6lQiqdZi8u21/BKxEf4wyO0zyXHlk2cZI0h6BzfSjEMkYQz0vJDj6Mh6UQjlDOy48izqlNgTRe3wFDZshKyyMT1D0NpRVeS0l6jjkTSZz6MV533IIYfg29/+NqZMmYKKigrU19dj6tSpmDBhQkoZv2HDBpx44on49a9/jZEjR6KkpAQXXHABysvL0bVrVxQXF+PSSy/F6NGjccQRRwAAvvWtb+HQQw/Feeedh9tuuw3V1dW49tprcckll6TKCObPn4/+/fvj61//Onbt2oUHH3wQL774Ip5//vmMjiFnSWi+UBPKiCT7Uo6LwiRWC2TzBcy84Gh0QyWcCplU63lkYirWjtqARkdZD0TSUYTV54atwCdQvASZEl7dVoIcv/qsZL6ePBKoRULZC52TVf+6aoRTJYmMcCoCJkYGbZpyqVDV8ZRJEbT0TNzTfhnULkqqEEuR/7HaWbXDEbuPedxBqwm14R9U1MQinKx9L1PHk45JNrZ5hggjJdCOTBbeZVlENGkQzfCimlYmz4899himTp2KE088EdFoFKeffjruvvvu1H+vr6/HmjVrsGPHjtSyO++8MzW2trYW48aNw4IFC1L/PRaL4dlnn8XFF1+M0aNHo2PHjpg0aRJuuOGG1Ji6ujpcccUV2LBhAzp06IAhQ4bghRdewPHHH5/R/COGSvW+uqipqUFJSQnmv74QRZ06pJbvQ36fRXn+H2P3Iv/Luzjf/wNi63bM9z9C86OFvmXsYVEQ9ZfRx6L+7cWMSLh2kY4KzMy4fmfL22IPMjXNbiNCUu2OGGSbJfLay/dfM7pfb693AMjzr5sggiOeGk9vbdeQJHZMJBLKyGpdwr+MRTgZkayp9xNz1cOTbW97vX9dRkC2k1uPkRwW9dR7u4vlB8I7U92nai7PCKLq18nGsf0WimEJFvUsIuFRNo5lndic6fzIut7sFBMmsaAAW9Yl7r+wLML5/7d3tjFSlecbv2dmd2agXcAGYcHSElDrC1ojBLKooZpNUHyp/aCkGorGahshETdqKWjXSgv4GhpEie9+0K61fzFNISjdSoyKbYqQmII0lkVs424kqbDdt9mZef4fDFvnnGvZa86ZPTPMXL/EDxyfc84z88w+5z73y3UjxwNyWqTBh0iAvQft+YmYf+9BMkvwXPBsQE6LGNIERXs3aKvpcgP+cZnewn8jBwhq0dnX5z/WXziPY/8
dsAkXP2hHjx6lcytLxXG7YfaqLZZIf62oc3P9PbZ77Q/KMu+TgZr1hHolmlCYHW1maFNFx1hheuzN5Eig/M8wOT5Bzw0TPq8g6FxPNv8Tyixx73zImPSG31FIHYb2yfdMZKigCnfkpWS9QDj/0z+OzfVkDU42vxueC+7hnTOb/wk9weQ45M1kjUsElPclXRI4vM+NY73SCDSOeYghQxLNDeeJouuBinQU7Q4hTI88oSwoVSeG9ig2Z5OphEcgpwUMxxM/HjbfYxSJBfCE5ishjaCCOfksBSGEEEIIcdJTs57QoB2TwhDmzZYG6rKxgvDget5cRxSqKXWiW6lD7+w9ECikjiDnh34DSHSe0fFkC5PQMeTxYXuuswU3rCZoGA8nnktwr2c/6aTxV8dz10deSnYcCyvCz4byYedacs6ogInNYw36HaCcUOSBZ3WgWfWTBPCOssL0iDDdllBhEpoLUl2hNZ7RsaAV7WicN0+0AvJG4654T2i8tjIei6ZmjVCmYxILG+phQ++4JzySaCJlOGDcDd2Y2FTYohw2RB9GwiNM/idb9U5fj7tHnszjZCSaco7NJQVC8iCHEz0u2Ep4dD02zM4al9hY5bojoWccOz/mZZWXYxr5WmZmqOSDFYNHoJAyepbCzkV0sRJaC25vZN/P0XfgNQjR7x2lVpW6kAhKaIVYMwQOs3MLhPvEhwjHM41LYL0AW8VXedXxsbyzWJHzKHZ8rVGzRqjXExqFV5TN+8EGJ9p8QljO6KnM5H+yOqFoc0NtNskNNEylOb5ciDabyOgG3wtueen/XtDvAns0C8/Fkkr+tYCV5sBgQIYk9phyBUdoHCufxBqDyHPJ9oRn/+aZZwjUAgSetiwowoqTLSX7we+ErmYH3yf0XIJj6Dv+Gvnng0DfJ1uYBHNgffm5nE4q7nrEKTXUA8cDuh7W60SKKKwmNfJwcs8QmI6O9rKBXv8xtq4ARcq8QK/nydFFKZ4zixdpLMTL78CtaGrWCPW27RwLqtnR2zMr24QqG1EhERrH6r5h904J5ZgQbHV8GL1SRIh2l/T1SOMSXw8cI/cq9gHkfdigFxjYAhL1byZhDU62OQOr9Yk8nKXW3WSN1UHUx95zDBmSyOBEFNsG8Kv0DQJNzHpUpR34FtijS74kIK1PBC01NcqwUS1W/B5Vx7OOBzy/CB7ZbOSMiXZBDyq3sG6w8EXcZQNKQpWQIBJNxY6vNWrWCE3G8Nv3V2HFh8NpgpLjIugTD3XfghqrbOgdGX7sOPa+bE/4MILzIDWCr4T3f17GY8qG3tEeiMLsyBOKQPfoBUvWRxoqdDierJgP2vXIDBuOTMeTTAY0nWBzDpFHDnnzUGibvAer18lWuLPpS7jLE/gc4L7I2cqkPaTIivwogM8B8MWz1ezIgI0DRwY0dNGLIhSJB8+GLJBjgnqfnHax/54nh6EmI7T0qDpeCCGEEEJETs16QhnQ2z7KBWKrcXGvd24u8Fz0DgFDKZw3E3UHcl4hdag1B95+w1Szs15PshgI5n+yYXY2/xMIzqOqd7atJqyO9/aOR7qeSDQehONRmB1XuAfP9UTzw4LzIORPelGRt5V1PGQG/L93ts8zM47N/2THQe8o6UVFIXoUSkFeT5SfCz2hyNtM7m9skRCTfgG1RKHWJ7on+O5A1XuYansEW3jKPkNgZT0b6WK74jGFSRn0B0oWK1UgKkwqPTVrhAaVaEIGJwovIdi8H/5cMuRNhtQdMIZ8RhjqdhGm/zvcyIL/LGGbTZYQgvOl7uOOwvGDnocI7u7D5XAi2D7xrKQSDoEHL1bCLT+5cf19/nVkQ9nMuazxWkeWgSPDFBmS6DOgY9DQJSsm2Cp/thIctUZF3ZYQbAGTfwxXHY9AK4Y6JmGj1v8doy55dP93sIficWAyQfu/m/Gi895jYSSaMoMn/ncZUDi+9NSsEZpKFBYesUVI6BgrxwSPgQ2EPRd6INlK9aCFQ0FzRM3C9Z0Pox2KjEv
o4Qz+koCMRgTsB43ycwljBconASMP54Ry49icS+yl9N8DVsKjZxmZO4qKhti8TnQMeUcR3nNhah10Awb/jaHrZcH16kCRJWuYJpL+c1n5qTCwclmowMp7bpx8EUN7PoItukOw+Z/ImwnzOtnfD/popCIKdEZAxwD5/GHmwRwrtR51AOI5F6A6XkboiahZIzQRK9xI2XacYWBbuNHnsgYhWe1IheNZryfaMNgwOwn0epbaWAXjcqAXO4KVaGL1Pr1GJzYaORF6KNsEjrHySegYAoXZw1yPNRqDFhx9ee7IBiwMgQNjMMxzNAtSLZCXBRmrrBcVGfX1pLGGXhLQnx4qOELPaVZ+yndPMjLFnsuK1SPLD+3bbMFimJbOtJ4zimyx12P2fTL0zrYWLjcSqy89NWuExmOuYHPBG43/POQBYMMmUB+ODNHTwvQI+h7IMvWciww15OGEmqDkuSEMRBq0IUPJJ3AuDAsDQxL1die9ntk8qnwv/J2xfd1LreuJZZaAAQuWlg3bI1BYnDUus9mRZZa+HAdeOsC4hMcwGyTlmEDqMB1SBz8JiwM7pQ78tNnwPvTe0h5DaljJ8f5+6oH3gJVeQiCDsz6Ocrb956I9H8rtAejQO9qk2PA51BgFzghWms8LyvVEx6iOSeX3hFqAnNCTpfK/XNSwEVq4CSHjMgnzfpAxyG0M+BhnmLJakqE6EMGQCBGOZ8PnYUL5YUTtofGLPAX+Q6yXEueJcrJNObDBw2IdzzE2/5OVaGI1PJFhyh7D0ktcSB0eA0Yj47kc7hgyOBmPKRrDGqb1yBgE10um/L87nOtZ2oc19I6CvRF2ZUK/KZilwL3tMR5TtEezefusYYrmwYb3ywaZDuVyIPeSyf80C95qExqr2RP/uwwoJ7T0SKJJCCGEEEJETg17Ql3B2zGbR4TelNlcT1qsHmaph8n/DPEG6fUiBg3LsNcfBrrqnc07rUuCuaDQLqhwB2H2LNkTHveJR8UgI/d2Z+WTwojLlzr03gvcYKxUUhivJ/JKouuxeaLecDkrswTzMKE30z+urxd0XINSTr5DuDU3+R2jiv5B0uuHIkzpELJNrEKC/zz0dwG8ubCAKfj+xhYihlFOwTcOoU4SZj/3AiUDKiCsHhAVJpWemjVC6+KuoPUaCqWgEA7sjgQcymw7TrT5oHONDZGgjhVIgw5dD8GE48NQ4s5F8HqkoYvzbkFhAXyIwPJW6h5YIxEYCJ4HJMr/ZLVDEWFC7+z1YB4maSBmBjijkTU4EWh+QcPxKAwH8zDJoiYWXKzlfyOIh6i8hGFx8lz0wjKWPJmpymc1PHF+P1uExIH2cjYnFGtD+8EtnUGHI2BwwjoAdo9H47xh9Qy4Zw7cE8kvef9+KqDARzqhpadmjdC8ixVsLmF+J3S+JoAWsC/1mzKbY8nsR1H0dQ/T6x2Mc6CiA3k4oa4neQxXwvs3W6jPCb2XheOwuDww3sAaIq8nrlIH8wDPC1Y+CbW37O8LLho/AAwuZIRBiSLSMGW8nGxFOpobMjgzOTAOCs4HBxrOpEg+XB9wLjRWgc2EXliQJijTVpRVOoGyTcDgjAOxeha0LyRA/3dkSCZAkw00Dr48s/3fEUEr4RFhJGYqMic0b/EiPbnFjq81atYITSVcgfeT3bjCVLiz1Y40rNcP6YkikGfVO+cwnlDWCEUbKFsdH8L4RR2OYIU7cHGyhilblY68OYwnNIzMEhIRZzU8UZidrWZnK9JZXU/WaEJzQbAGMXN9aJgCwxwZpuh6qFgJhuihxiiCSw1AoN8UMiRhTSDYBvn+9MQ9ofEa3POAPKth7C1GF9hsmCgMqwnKpmWxWtOomKh/YOQxsBIeXN+7aBXgUVRhUumpWSO0Pu5GbLOGcoHQ2yn7ZsuK1UPPKgqzI1hBePZ6jNEZpmKeNUyhDBT4+YYJ7wOY9plmw1THkw8Wtsrda/yxBicrBM7KJyGvJ1vNjkLqyODE+Z9cOJ41TOnOQkT+I3t
91kubRzJD5PUQdHvPOvAyhYx6Uk8U/VaQ4DwyONG5yBnsNTrr0GclQ+pQygmMQ397WDGA7HpEaoJipwUp0QTD5+RzgMXrQECtQqFhinKSJFZfC6g6XgghhBBCRE7NekK9pEJo67Gt1EJVQKIQ9WCIysagc0Fvzqzwe4m9lBBU9Q5gw+ysh5MNx/OtNkcOx6OCI76vOxd6h/mkwJsJPZdAIB4WHIFwNC1CT3pgEXAc6kAElAq88P3a/eNQSJ31cOKuTFyxVgrpjoI1QwVMUGOVLFZiW7IijykD+tvBRabk9cAxVMiKwvG49SappoJSv+CzJoSnjQ69s5G4wuuF6YTkPJuZ99/lIOYCFCZVQEFVJVOzRqi/Y5J/DFsJjyqjw2w09MZAJp/DCkgW71xC9FeHeZ1hROhZYXpSeimMcYkNTi7Xk5VV8h5DD+4e8DNhQ+pob0UGLBsqR4YUKiTiRei5eyDiqNsK+vmALyEW9OEHrJw82BdgsRJZRc8WVyFQ7mgyEWKvQPcg8yTRuNFu5cnaEuiWbHckrJISvK4AwqZDhal6RzDdkFCuJyNMb+b/AVRAWFs5oaWnIozQTZs22cMPP2ydnZ323e9+1zZu3Ghz584ddvyrr75q9913nx06dMjOOOMMe/DBB23RokVF3bOOyAllE83R214c5ALRRUjsJkAmmlM94VnC5HWSBmcM5N3S1fFgHMzhBN8xKkxC1exZ0DseG6bAkCTzP/uAR8rrRWU7HPWAa6Ee7lCOiTQaUa4n8r7B60E5JpDXiVpKIq3HAc6QClES6HuwZOv8nwHNYzCFVBmCF02FqZhn9Unr6jhlAeTRZe0GNI7V9mQNXeaeiDD+N7QvQBlosB8lwAs1lIJj6wAQrMMDGVJIVslL0BadZuY896yE/vIyQktP2Y3QV155xVpaWmzz5s02b94827Bhgy1cuNAOHDhgkyZN8o1/77337Ic//KGtW7fOrrrqKnv55Zft2muvtQ8++MBmzZpF39fbthMVIbFFQ6gwia6YDyNMTwI1QYOGxkutE8rCitADYDU7cBnDNaMfov7vZQDI7MBwfMAqd7YICcEWJuHwOZCUIqWSQlWpg2cja3CyDwIUaqsDln3W45KrzwADDBimCXCtHPKYIk81sLaC6pqa8T3r2RA9WsccSHMKozEKxe89PzNcCU/eABAmYQjqDKMQPdJyZjcfshI+VESMLAryaYCSBifE+zfFChSPIvGco9UivnqOGJ6yG6GPPfaY3XrrrXbzzTebmdnmzZtt69at9txzz9nKlSt943/zm9/Y5ZdfbnfffbeZma1Zs8Z27Nhhjz/+uG3evJm+75h43sZ8xfDEOT6cdhtb7YiMVZhrA99EufxPWoQ+TLiGOS8BcjPZVi6IMB5YEtzhiOsdj6vZuXA87gIzshGKc+v85/WAnwQ0ONl+7cjDifI6SQ8nzP8MYXCy4+pQiB6AvJzeBwvt7QDX8np8zHB1PJZZClbNPxwoRJ+uB0L3Je6TDmWbgAcW/eaDSjQh2LQsLNHEOSOgcgp61qDnRZivnZVtohVWCC8n+uJB6B3le/o8oZXgUcwHEJ8vv+1c0ZS1Oj6Tydju3butubl56Fg8Hrfm5mbbtWsXPGfXrl0F483MFi5cOOz4gYEBO3bsWMF/QgghhBCivJTVE3rkyBHL5XI2efLkguOTJ0+2jz76CJ7T2dkJx3d2dsLx69ats1/+8pe+418c7S/wXrmU/+0MJeknwasySpavB2GoRAyEgGGoHL3u91PHHCtMH7RnMO199LeNi9WP5U5FXhbUvznd4DuUN/93gryZqNd7Jt9HjRvM+++BKtW7s37vRjfwGPYAT2APEH/v83yMo2Cpu8F5/wU/sZ4MJy7f1+P/zfb2+qMBGeB9RHmiA/1AUxd4JOtAeBt5QuoGQSERuB7yVOaQFicqTCJiudBzCdYVXckh1x243iDYaBIp4JHLg/xk5LkD4xwIecZIofI42FOywHXZBz4
uiPhbPfpsRJM0lHrvgFc1D5QF6kFeq4G9PAtycdMJ/+evi/uvh7yodeAe9eBc2MZ+oMd/jH02wHNBy88+8Pzp9Y9z/X0jjrEBkFPfDa7VXbhxdfd/+e9y5obmMr2WLdITmsv6nyvif5Q9HD/a/PznP7eWlpahf//73/+2c845x5rPv6eMsxJCCCFEsXR3d9v48eMjvWcymbTGxkb7vzdXBDq/sbHRkklOPrDWKKsROnHiREskEtbV1VVwvKuryxobG+E5jY2NRY1PpVKWSqWG/v31r3/dPv30U2toaLDu7m6bNm2affrppzZu3LiQn0YE5dixY1qHCkDrUBloHSoDrUNlcHwdDh8+bLFYzKZOnRr5HNLptHV0dFgmE6zDVDKZtHQ6XeJZVQdlNUKTyaTNnj3b2tvb7dprrzUzs3w+b+3t7bZ8+XJ4TlNTk7W3t9uKFSuGju3YscOampqoe8bjcfvmN79pZv+Tuxg3bpw2mQpA61AZaB0qA61DZaB1qAzGjx9f1nVIp9MyJEeBsofjW1pabOnSpTZnzhybO3eubdiwwXp6eoaq5X/0ox/ZaaedZuvWrTMzszvuuMMWLFhgjz76qF155ZXW1tZmf/vb3+ypp54q58cQQgghhBBFUHYjdPHixfb555/bL37xC+vs7LQLLrjAtm/fPlR8dPjwYYt/RcZn/vz59vLLL9u9995rq1atsjPOOMNef/31ojRChRBCCCFEeSm7EWpmtnz58mHD7zt37vQdu+666+y6664Lfd9UKmWtra0FOaMierQOlYHWoTLQOlQGWofKQOtQ3cRcJfTCEkIIIYQQNUVZxeqFEEIIIURtIiNUCCGEEEJEjoxQIYQQQggROVVvhG7atMmmT59u6XTa5s2bZ3/9619POP7VV1+1s846y9LptJ133nm2bdu2iGZa3RSzDk8//bRdcskldsopp9gpp5xizc3NI66b4Cj27+E4bW1tFovFhvR8RTiKXYcvvvjCli1bZlOmTLFUKmVnnnmm9qYSUOw6bNiwwb7zne/YmDFjbNq0aXbnnXdafz9oaSko3n77bbv66qtt6tSpFovF7PXXXx/xnJ07d9qFF15oqVTKTj/9dHvhhRdGfZ5iFHFVTFtbm0smk+65555zf//7392tt97qJkyY4Lq6uuD4d9991yUSCffQQw+5ffv2uXvvvdfV19e7Dz/8MOKZVxfFrsMNN9zgNm3a5Pbs2eP279/vbrrpJjd+/Hj3r3/9K+KZVxfFrsNxOjo63GmnneYuueQS9/3vfz+ayVYxxa7DwMCAmzNnjlu0aJF75513XEdHh9u5c6fbu3dvxDOvLopdh5deesmlUin30ksvuY6ODvfGG2+4KVOmuDvvvDPimVcP27Ztc6tXr3avvfaaMzO3ZcuWE44/ePCgGzt2rGtpaXH79u1zGzdudIlEwm3fvj2aCYuSU9VG6Ny5c92yZcuG/p3L5dzUqVPdunXr4Pjrr7/eXXnllQXH5s2b537yk5+M6jyrnWLXwUs2m3UNDQ3uxRdfHK0p1gRB1iGbzbr58+e7Z555xi1dulRGaAkodh2efPJJN2PGDJfJZKKaYk1Q7DosW7bMXXbZZQXHWlpa3EUXXTSq86wVGCP0nnvuceeee27BscWLF7uFCxeO4szEaFK14fhMJmO7d++25ubmoWPxeNyam5tt165d8Jxdu3YVjDczW7hw4bDjxcgEWQcvvb29Njg4aN/4xjdGa5pVT9B1eOCBB2zSpEl2yy23RDHNqifIOvzhD3+wpqYmW7ZsmU2ePNlmzZpla9eutVwuF9W0q44g6zB//nzbvXv3UMj+4MGDtm3bNlu0aFEkcxZ6RlcjFSFWPxocOXLEcrncUOel40yePNk++ugjeE5nZycc39nZOWrzrHaCrIOXn/3sZzZ16lTf5iN4gqzDO++8Y88++6zt3bs3ghnWBkHW4eDBg/bnP//ZbrzxRtu2bZt9/PHHdvvtt9vg4KC1trZGMe2
qI8g63HDDDXbkyBG7+OKLzTln2WzWfvrTn9qqVauimLKw4Z/Rx44ds76+PhszZkyZZiaCUrWeUFEdrF+/3tra2mzLli2WTqfLPZ2aobu725YsWWJPP/20TZw4sdzTqWny+bxNmjTJnnrqKZs9e7YtXrzYVq9ebZs3by731GqKnTt32tq1a+2JJ56wDz74wF577TXbunWrrVmzptxTE+KkpWo9oRMnTrREImFdXV0Fx7u6uqyxsRGe09jYWNR4MTJB1uE4jzzyiK1fv97+9Kc/2fnnnz+a06x6il2Hf/7zn3bo0CG7+uqrh47l83kzM6urq7MDBw7YzJkzR3fSVUiQv4cpU6ZYfX29JRKJoWNnn322dXZ2WiaTsWQyOapzrkaCrMN9991nS5YssR//+MdmZnbeeedZT0+P3XbbbbZ69WqLx+XTGW2Ge0aPGzdOXtCTlKr9q0kmkzZ79mxrb28fOpbP5629vd2amprgOU1NTQXjzcx27Ngx7HgxMkHWwczsoYcesjVr1tj27dttzpw5UUy1qil2Hc466yz78MMPbe/evUP/XXPNNXbppZfa3r17bdq0aVFOv2oI8vdw0UUX2ccffzz0EmBm9o9//MOmTJkiAzQgQdaht7fXZ2gefzFw6n4dCXpGVyHlrowaTdra2lwqlXIvvPCC27dvn7vtttvchAkTXGdnp3POuSVLlriVK1cOjX/33XddXV2de+SRR9z+/ftda2urJJpKQLHrsH79epdMJt3vf/9799lnnw39193dXa6PUBUUuw5eVB1fGopdh8OHD7uGhga3fPlyd+DAAffHP/7RTZo0yf3qV78q10eoCopdh9bWVtfQ0OB++9vfuoMHD7o333zTzZw5011//fXl+ggnPd3d3W7Pnj1uz549zszcY4895vbs2eM++eQT55xzK1eudEuWLBkaf1yi6e6773b79+93mzZtkkTTSU5VG6HOObdx40b3rW99yyWTSTd37lz3/vvvD/2/BQsWuKVLlxaM/93vfufOPPNMl0wm3bnnnuu2bt0a8Yyrk2LW4dvf/rYzM99/ra2t0U+8yij27+GryAgtHcWuw3vvvefmzZvnUqmUmzFjhvv1r3/tstlsxLOuPopZh8HBQXf//fe7mTNnunQ67aZNm+Zuv/1295///Cf6iVcJb731Ftzrj3/vS5cudQsWLPCdc8EFF7hkMulmzJjhnn/++cjnLUpHzDnFEYQQQgghRLRUbU6oEEIIIYSoXGSECiGEEEKIyJERKoQQQgghIkdGqBBCCCGEiBwZoUIIIYQQInJkhAohhBBCiMiRESqEEEIIISJHRqgQQgghhIgcGaFCCCGEECJyZIQKIaqC733ve7ZixYpyT0MIIQSJjFAhhBBCCBE56h0vhDjpuemmm+zFF18sONbR0WHTp08vz4SEEEKMiIxQIcRJz9GjR+2KK66wWbNm2QMPPGBmZqeeeqolEokyz0wIIcRw1JV7AkIIEZbx48dbMpm0sWPHWmNjY7mnI4QQgkA5oUIIIYQQInJkhAohhBBCiMiRESqEqAqSyaTlcrlyT0MIIQSJjFAhRFUwffp0+8tf/mKHDh2yI0eOWD6fL/eUhBBCnAAZoUKIquCuu+6yRCJh55xzjp166ql2+PDhck9JCCHECZBEkxBCCCGEiBx5QoUQQgghROTICBVCCCGEEJEjI1QIIYQQQkSOjFAhhBBCCBE5MkKFEEIIIUTkyAgVQgghhBCRIyNUCCGEEEJEjoxQIYQQQggROTJChRBCCCFE5MgIFUIIIYQQkSMjVAghhBBCRI6MUCGEEEIIETn/D7448ZeoVYPaAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "output = model(jnp.array(coords))\n", "diff = output - u(coords[:,0], coords[:,1], 0.25).reshape(-1, 1)\n", "resplot = np.array(diff).reshape(N_t, N_x)\n", "\n", "plt.figure(figsize=(7, 4))\n", "plt.pcolormesh(T, X, resplot, shading='auto', cmap='Spectral_r')\n", "plt.colorbar()\n", "\n", "plt.title('Difference between approximated and analytical solution')\n", "plt.xlabel('t')\n", "plt.ylabel('x')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "349b88bc-5bc9-44f5-b255-19e25c5384d6", "metadata": {}, "source": [ "The approximation is accurate: the maximum absolute error over the domain is $\\sim 0.02$.\n", "\n", "Finally, the estimated diffusion coefficient $\\tau$ (denoted $D$ in the PDE above) is close to the true value $D = 0.25$ that was used to generate the mock experimental data:" ] }, { "cell_type": "code", "execution_count": 11, "id": "4ebf96bf-59f6-48cc-b40f-821cb1d831aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The approximated value for τ is 0.24094036221504211.\n" ] } ], "source": [ "print(f\"The approximated value for τ is {model.tau[0]}.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0e936047-8251-4b91-b04e-b780643a27dc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }